Text-to-Image
Diffusers
Safetensors
LibreFluxIPAdapterPipeline
neuralvfx committed on
Commit 7feb367 · verified · 1 Parent(s): 60159a1

Upload folder using huggingface_hub

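This commit uploads a self-contained custom Diffusers pipeline (`LibreFluxIPAdapterPipeline`, declared in `model_index.json` and implemented in `pipeline.py` below) together with the IP-Adapter weights (`ip_adapter.pt`) and a bundled SigLIP image encoder. A minimal, unverified loading sketch — the repo id is a placeholder, the `custom_pipeline` loading path is an assumption about how the repo is meant to be consumed, and only `ip_adapter_image` / `ip_adapter_scale` are taken from the `__call__` signature shown in `pipeline.py`:

```py
import torch
from diffusers import DiffusionPipeline
from PIL import Image

# Placeholder repo id — substitute the actual Hub path of this repository.
repo_id = "neuralvfx/LibreFlux-IP-Adapter"

# `custom_pipeline=repo_id` asks Diffusers to execute the pipeline.py shipped in the repo;
# the pipeline class is resolved from model_index.json. Recent Diffusers releases may also
# require `trust_remote_code=True` for remote code.
pipe = DiffusionPipeline.from_pretrained(
    repo_id,
    custom_pipeline=repo_id,
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

ref = Image.open("examples/mona.jpg").convert("RGB")
image = pipe(
    "an oil painting of a woman in a forest",
    ip_adapter_image=ref,      # reference image, encoded by the bundled SigLIP image encoder
    ip_adapter_scale=1.0,      # strength of the injected image tokens
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("libreflux_ip_adapter.png")
```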
.gitattributes CHANGED
@@ -1,35 +1,18 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ examples/side_by_side.png filter=lfs diff=lfs merge=lfs -text
+ examples/libre_flux_control_image.png filter=lfs diff=lfs merge=lfs -text
+ examples/side_by_side_a.png filter=lfs diff=lfs merge=lfs -text
+ examples/side_by_side_b.png filter=lfs diff=lfs merge=lfs -text
+ examples/david.jpg filter=lfs diff=lfs merge=lfs -text
+ examples/matrix_edge.png filter=lfs diff=lfs merge=lfs -text
+ examples/mona.jpg filter=lfs diff=lfs merge=lfs -text
+ image_encoder/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ image_encoder/spiece.model filter=lfs diff=lfs merge=lfs -text
+ ip_adapter.pt filter=lfs diff=lfs merge=lfs -text
+ text_encoder/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ text_encoder_2/model-00001-of-00002.safetensors filter=lfs diff=lfs merge=lfs -text
+ text_encoder_2/model-00002-of-00002.safetensors filter=lfs diff=lfs merge=lfs -text
+ tokenizer_2/spiece.model filter=lfs diff=lfs merge=lfs -text
+ transformer/diffusion_pytorch_model-00001-of-00003.safetensors filter=lfs diff=lfs merge=lfs -text
+ transformer/diffusion_pytorch_model-00002-of-00003.safetensors filter=lfs diff=lfs merge=lfs -text
+ transformer/diffusion_pytorch_model-00003-of-00003.safetensors filter=lfs diff=lfs merge=lfs -text
+ vae/diffusion_pytorch_model.safetensors filter=lfs diff=lfs merge=lfs -text

examples/david.jpg ADDED

Git LFS Details

  • SHA256: 0e6a82b4100d889365ab7e7856bb1f5ebe5b3d8e2dcad6b1f45aa80885e3f49e
  • Pointer size: 131 Bytes
  • Size of remote file: 150 kB
examples/matrix_edge.png ADDED

Git LFS Details

  • SHA256: d69193c1f781ff8024300bc114fd686469f9f7c54e8a0d2a36d82e8a4bb110a2
  • Pointer size: 132 Bytes
  • Size of remote file: 2.01 MB
examples/mona.jpg ADDED

Git LFS Details

  • SHA256: ec04c3335ddf7d9122e63f609885d95b9518c456c11075ab155ad2e8b8c85a00
  • Pointer size: 131 Bytes
  • Size of remote file: 270 kB
flux_ip_adapter.py ADDED
@@ -0,0 +1,332 @@

from itertools import chain
import torch
from torch import nn
from diffusers.models.attention_processor import (
    Attention,
    AttentionProcessor,
)

from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
import torch.nn.functional as F
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.models.attention_processor import Attention
import inspect
from functools import partial
from diffusers.models.normalization import RMSNorm
from typing import Any, Dict, List, Optional, Union
import torch
import torch.nn as nn


class IPFluxAttnProcessor2_0(nn.Module):
    """Attention processor for Flux-style joint self-attention, extended with IP-Adapter image tokens."""

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, num_heads=0):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

        self.norm_added_k = RMSNorm(128, eps=1e-5, elementwise_affine=False)

    def __call__(
        self,
        attn,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        ip_encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        layer_scale: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        ip_hidden_states = ip_encoder_hidden_states

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Handle the IP-Adapter attention first.
        if ip_hidden_states is not None:
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            # Reshape to match the query shape.
            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            ip_key = self.norm_added_k(ip_key)

            # Flux-style attention over the reference-image tokens.
            ip_hidden_states = F.scaled_dot_product_attention(
                query,
                ip_key,
                ip_value,
                dropout_p=0.0,
                is_causal=False,
                attn_mask=None,
            )

            # Reshape ip_hidden_states the same way as hidden_states.
            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
                batch_size, -1, attn.heads * head_dim
            )

        # The attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`.
        if encoder_hidden_states is not None:
            # `context` projections.
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(
                    encoder_hidden_states_query_proj
                )
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(
                    encoder_hidden_states_key_proj
                )

            # attention
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            attention_mask = (attention_mask > 0).bool()
            attention_mask = attention_mask.to(
                device=hidden_states.device, dtype=query.dtype
            )
        original_hidden_states = hidden_states

        hidden_states = F.scaled_dot_product_attention(
            query,
            key,
            value,
            dropout_p=0.0,
            is_causal=False,
            attn_mask=attention_mask,
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        layer_scale = layer_scale.view(-1, 1, 1)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            # Final injection of the IP-Adapter hidden_states.
            if ip_hidden_states is not None:
                hidden_states = hidden_states + (self.scale * layer_scale) * ip_hidden_states

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states

        else:
            # Final injection of the IP-Adapter hidden_states.
            if ip_hidden_states is not None:
                hidden_states = hidden_states + (self.scale * layer_scale) * ip_hidden_states

            if attn.to_out is not None:
                hidden_states = attn.to_out[0](hidden_states)
                hidden_states = attn.to_out[1](hidden_states)

            return hidden_states


class ImageProjModel(nn.Module):
    def __init__(self, clip_dim=768, cross_attention_dim=4096, num_tokens=16):
        super().__init__()

        self.num_tokens = num_tokens
        self.cross_attention_dim = cross_attention_dim
        self.clip_dim = clip_dim

        self.proj = torch.nn.Sequential(
            torch.nn.Linear(clip_dim, clip_dim * 2),
            torch.nn.GELU(),
            torch.nn.Linear(clip_dim * 2, cross_attention_dim * num_tokens),
        )
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, input):
        raw_proj = self.proj(input)
        reshaped_proj = raw_proj.reshape(input.shape[0], self.num_tokens, self.cross_attention_dim)
        reshaped_proj = self.norm(reshaped_proj)

        return reshaped_proj


class LibreFluxIPAdapter(nn.Module):
    def __init__(self, transformer, image_proj_model, checkpoint=None):
        super().__init__()
        self.transformer = transformer
        self.image_proj_model = image_proj_model

        # `startswith` selects the attention modules in both the double (`transformer_blocks`)
        # and single (`single_transformer_blocks`) blocks and skips everything else.
        self.culled_transformer_blocks = {}
        for name, module in self.transformer.named_modules():
            if isinstance(module, Attention):
                if name.startswith('transformer_blocks') or name.startswith('single_transformer_blocks'):
                    # print(f"Using Transformer: {name}")
                    self.culled_transformer_blocks[name] = module
                # else:
                #     print(f"Ignoring Transformer: {name}")
        # Apply the adapter to the culled blocks
        self.wrap_attention_blocks()

        if checkpoint:
            self.load_from_checkpoint(checkpoint)

    def wrap_attention_blocks(self, scale=1.0, num_tokens=16):
        """Inject the IP-Adapter modules into the transformer model."""
        sample_attn = self.transformer.transformer_blocks[0].attn

        hidden_size = sample_attn.inner_dim
        cross_attention_dim = sample_attn.cross_attention_dim
        num_heads = sample_attn.heads
        scale = 1.0
        num_tokens = 16

        processor_list = []
        for name in self.culled_transformer_blocks:
            module = self.culled_transformer_blocks[name]
            module.processor = IPFluxAttnProcessor2_0(
                hidden_size=hidden_size,
                cross_attention_dim=4096,
                num_heads=num_heads,
                scale=1.0,
                num_tokens=16,
            )
            processor_list.append(module.processor)
        lay_count = len(processor_list)
        print(f"Added Attention IP Wrapper to {lay_count} layers")

        # Store adapters as a module list for saving/loading
        self.adapter_modules = torch.nn.ModuleList(processor_list)

    def parameters(self):
        """Easy way to return all trainable params."""
        adapter_param_list = []
        for name in self.culled_transformer_blocks:
            module = self.culled_transformer_blocks[name]
            adapter_param_list.append(module.processor.parameters())

        all_params = chain(*adapter_param_list, self.image_proj_model.parameters())
        return all_params

    def forward(self, ref_image, *args, layer_scale=torch.Tensor([1.0]), **kwargs):
        """Run the image projection, then the transformer forward pass."""
        mod_dtype = next(self.image_proj_model.parameters()).dtype
        mod_device = next(self.image_proj_model.parameters()).device

        ip_encoder_hidden_states = None
        if ref_image is not None:
            ip_encoder_hidden_states = self.image_proj_model(ref_image)

        # Add IP hidden states to kwargs
        if 'joint_attention_kwargs' not in kwargs:
            kwargs['joint_attention_kwargs'] = {}
        layer_scale = layer_scale.to(dtype=mod_dtype, device=mod_device)

        kwargs['joint_attention_kwargs']['ip_layer_scale'] = layer_scale
        kwargs['joint_attention_kwargs']['ip_hidden_states'] = ip_encoder_hidden_states

        output = self.transformer(*args, **kwargs)

        return output

    def save_pretrained(self, ckpt_path):
        """Save model weights."""
        state_dict = {}

        state_dict["image_proj"] = self.image_proj_model.state_dict()
        state_dict["ip_adapter"] = self.adapter_modules.state_dict()
        torch.save(state_dict, ckpt_path)

    def load_from_checkpoint(self, ckpt_path):
        """Checkpoint loader adapted from the Tencent IP-Adapter repo."""
        # Calculate original checksums
        orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
        orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

        state_dict = torch.load(ckpt_path, map_location="cpu")

        # Load state dict for image_proj_model and adapter_modules
        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)

        # Calculate new checksums
        new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
        new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

        # Verify that the weights have changed
        assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
        assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"

        print(f"Successfully loaded weights from checkpoint {ckpt_path}")

    @property
    def dtype(self):
        return next(self.image_proj_model.parameters()).dtype
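The classes above define the adapter itself; the part of `pipeline.py` that attaches it to the transformer is not fully visible in this diff, so the wiring below is a hedged sketch. The constructor arguments and `joint_attention_kwargs` keys come from the code above; the 1152-dim pooled SigLIP embedding and the import path are assumptions based on `image_encoder/config.json` and the repo layout:

```py
import torch
from flux_ip_adapter import ImageProjModel, LibreFluxIPAdapter  # assumes the repo root is on sys.path

# `transformer` is a LibreFluxTransformer2DModel loaded elsewhere (e.g. from this repo's
# transformer/ subfolder); it is assumed here, not shown.
proj = ImageProjModel(
    clip_dim=1152,             # width of the SigLIP pooled embedding (see image_encoder/config.json)
    cross_attention_dim=4096,  # width consumed by to_k_ip / to_v_ip in the attention processors
    num_tokens=16,             # image tokens injected into every patched attention layer
)
adapter = LibreFluxIPAdapter(transformer, proj, checkpoint="ip_adapter.pt")

# The wrapper projects the pooled SigLIP embedding to 16 tokens of width 4096 and routes them to
# every patched processor via joint_attention_kwargs["ip_hidden_states"]; `layer_scale` scales the
# injected contribution per sample.
noise_pred = adapter(
    image_embeds,                      # pooled SigLIP embedding, shape (batch, 1152)
    layer_scale=torch.tensor([1.0]),
    hidden_states=latents,             # packed latents, plus the remaining
    # LibreFluxTransformer2DModel keyword arguments (timestep, ids, text embeddings, ...)
)
```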
image_encoder/config.json ADDED
@@ -0,0 +1,18 @@

{
  "architectures": [
    "SiglipVisionModel"
  ],
  "attention_dropout": 0.0,
  "dtype": "float32",
  "hidden_act": "gelu_pytorch_tanh",
  "hidden_size": 1152,
  "image_size": 384,
  "intermediate_size": 4304,
  "layer_norm_eps": 1e-06,
  "model_type": "siglip_vision_model",
  "num_attention_heads": 16,
  "num_channels": 3,
  "num_hidden_layers": 27,
  "patch_size": 14,
  "transformers_version": "4.57.1"
}
image_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@

version https://git-lfs.github.com/spec/v1
oid sha256:979cc297d21bc2634d7fc0e3fdc2589b0c891acf380708ffdeb1f988e3a2d817
size 1712957296
image_encoder/preprocessor_config.json ADDED
@@ -0,0 +1,24 @@

{
  "do_convert_rgb": null,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "SiglipImageProcessor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "processor_class": "SiglipProcessor",
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 384,
    "width": 384
  }
}
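The `image_encoder/` folder holds a SigLIP so400m vision tower (384×384 input, hidden size 1152) and its preprocessor (resize to 384, rescale by 1/255, normalize with mean/std 0.5). A small sketch of how a reference image is turned into the pooled embedding the adapter consumes — the local paths are assumptions, and `pipeline.py` itself builds its processor from `google/siglip-so400m-patch14-384`:

```py
import torch
from PIL import Image
from transformers import SiglipVisionModel, SiglipImageProcessor

# Assumed local paths: the image_encoder/ subfolder of this repository.
encoder = SiglipVisionModel.from_pretrained("./image_encoder", torch_dtype=torch.bfloat16).to("cuda")
processor = SiglipImageProcessor.from_pretrained("./image_encoder")

ref = Image.open("examples/david.jpg").convert("RGB")
# Resize to 384x384, rescale by 1/255, normalize with mean/std 0.5 (preprocessor_config.json).
pixel_values = processor(images=ref, return_tensors="pt").pixel_values.to("cuda", torch.bfloat16)

with torch.no_grad():
    # The pipeline uses the pooled output (one 1152-dim vector per image) as the
    # conditioning fed to the IP-Adapter projection model.
    image_embeds = encoder(pixel_values).pooler_output   # shape (1, 1152)
```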
image_encoder/special_tokens_map.json ADDED
@@ -0,0 +1,23 @@

{
  "eos_token": {
    "content": "</s>",
    "lstrip": true,
    "normalized": false,
    "rstrip": true,
    "single_word": false
  },
  "pad_token": {
    "content": "</s>",
    "lstrip": true,
    "normalized": false,
    "rstrip": true,
    "single_word": false
  },
  "unk_token": {
    "content": "<unk>",
    "lstrip": true,
    "normalized": false,
    "rstrip": true,
    "single_word": false
  }
}
image_encoder/spiece.model ADDED
@@ -0,0 +1,3 @@

version https://git-lfs.github.com/spec/v1
oid sha256:1e5036bed065526c3c212dfbe288752391797c4bb1a284aa18c9a0b23fcaf8ec
size 798330
image_encoder/tokenizer_config.json ADDED
@@ -0,0 +1,34 @@

{
  "added_tokens_decoder": {
    "1": {
      "content": "</s>",
      "lstrip": true,
      "normalized": false,
      "rstrip": true,
      "single_word": false,
      "special": true
    },
    "2": {
      "content": "<unk>",
      "lstrip": true,
      "normalized": false,
      "rstrip": true,
      "single_word": false,
      "special": true
    }
  },
  "additional_special_tokens": [],
  "clean_up_tokenization_spaces": true,
  "do_lower_case": true,
  "eos_token": "</s>",
  "extra_special_tokens": {},
  "model_input_names": [
    "input_ids"
  ],
  "model_max_length": 64,
  "pad_token": "</s>",
  "processor_class": "SiglipProcessor",
  "sp_model_kwargs": {},
  "tokenizer_class": "SiglipTokenizer",
  "unk_token": "<unk>"
}
ip_adapter.pt ADDED
@@ -0,0 +1,3 @@

version https://git-lfs.github.com/spec/v1
oid sha256:2a1fcb0c5d436335ae332704585105d547dd4908832cfd361d7a5a8101eefcd0
size 5291247917
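`ip_adapter.pt` is a plain `torch.save` checkpoint holding the trained projection model and the per-layer `to_k_ip` / `to_v_ip` weights, stored under the two keys that `save_pretrained` / `load_from_checkpoint` in `flux_ip_adapter.py` read. A quick inspection sketch (the path is assumed relative to the repo root):

```py
import torch

# Loads on CPU; the file holds two sub-state-dicts keyed exactly as
# flux_ip_adapter.py expects: "image_proj" and "ip_adapter".
state = torch.load("ip_adapter.pt", map_location="cpu")
print(state.keys())            # dict_keys(['image_proj', 'ip_adapter'])

# "image_proj" -> ImageProjModel weights; "ip_adapter" -> the ModuleList of
# IPFluxAttnProcessor2_0 modules (one to_k_ip / to_v_ip pair per patched layer).
print(sum(p.numel() for p in state["ip_adapter"].values()))
```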
model_index.json ADDED
@@ -0,0 +1,36 @@

{
  "_class_name": "LibreFluxIPAdapterPipeline",
  "_diffusers_version": "0.35.2",
  "image_encoder": [
    "transformers",
    "SiglipVisionModel"
  ],
  "scheduler": [
    "diffusers",
    "FlowMatchEulerDiscreteScheduler"
  ],
  "text_encoder": [
    "transformers",
    "CLIPTextModel"
  ],
  "text_encoder_2": [
    "transformers",
    "T5EncoderModel"
  ],
  "tokenizer": [
    "transformers",
    "CLIPTokenizer"
  ],
  "tokenizer_2": [
    "transformers",
    "T5TokenizerFast"
  ],
  "transformer": [
    "trans",
    "LibreFluxTransformer2DModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderKL"
  ]
}
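`model_index.json` tells `DiffusionPipeline.from_pretrained` which library and class to instantiate for each subfolder; the `"trans"` entry points at the local `transformer/trans.py` module bundled with this repo rather than at an upstream library. Individual components can also be loaded on their own — a hedged sketch with a placeholder repo id:

```py
from transformers import SiglipVisionModel, T5EncoderModel
from diffusers import AutoencoderKL

repo_id = "neuralvfx/LibreFlux-IP-Adapter"  # placeholder — use the actual repo id

# Each entry maps subfolder -> (library, class); these three use stock classes.
image_encoder = SiglipVisionModel.from_pretrained(repo_id, subfolder="image_encoder")
text_encoder_2 = T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_2")
vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae")
```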
pipeline.py ADDED
@@ -0,0 +1,1107 @@
1
+ # Copyright 2024 Stability AI, The HuggingFace Team, The InstantX Team, and Terminus Research Group. All rights reserved.
2
+ #
3
+ # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+ # Originally licensed under the Apache License, Version 2.0 (the "License");
18
+ # Updated to "Affero GENERAL PUBLIC LICENSE Version 3, 19 November 2007" via extensive updates to attn_mask usage.
19
+ #__all__ = ['FluxTransformer2DModelWithMasking', 'CustomPipeline']
20
+
21
+ from typing import Any, Dict, List, Optional, Union
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+
27
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
28
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
29
+ from diffusers.models.attention import FeedForward
30
+ from diffusers.models.attention_processor import Attention
31
+
32
+ from diffusers.models.modeling_utils import ModelMixin
33
+ from diffusers.models.normalization import (
34
+ AdaLayerNormContinuous,
35
+ AdaLayerNormZero,
36
+ AdaLayerNormZeroSingle,
37
+ )
38
+ from diffusers.utils import (
39
+ USE_PEFT_BACKEND,
40
+ is_torch_version,
41
+ logging,
42
+ scale_lora_layers,
43
+ unscale_lora_layers,
44
+ )
45
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
46
+ from diffusers.models.embeddings import (
47
+ CombinedTimestepGuidanceTextProjEmbeddings,
48
+ CombinedTimestepTextProjEmbeddings,
49
+ )
50
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
51
+
52
+ from dataclasses import dataclass
53
+ from typing import List, Union
54
+ import PIL.Image
55
+ from diffusers.utils import BaseOutput
56
+
57
+ import inspect
58
+ from functools import lru_cache
59
+ from typing import Any, Callable, Dict, List, Optional, Union
60
+
61
+ import numpy as np
62
+ import torch
63
+ from transformers import (
64
+ CLIPTextModel,
65
+ CLIPTokenizer,
66
+ T5EncoderModel,
67
+ T5TokenizerFast,
68
+ CLIPVisionModelWithProjection,
69
+ CLIPTextModelWithProjection,
70
+ CLIPImageProcessor
71
+ )
72
+
73
+ from diffusers.image_processor import VaeImageProcessor
74
+ from diffusers.loaders import SD3LoraLoaderMixin
75
+ from diffusers.models.autoencoders import AutoencoderKL
76
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
77
+ from diffusers.utils import (
78
+ USE_PEFT_BACKEND,
79
+ is_torch_xla_available,
80
+ logging,
81
+ replace_example_docstring,
82
+ scale_lora_layers,
83
+ unscale_lora_layers,
84
+ )
85
+ from diffusers.utils.torch_utils import randn_tensor
86
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
87
+
88
+ from PIL import Image
89
+
90
+ from .transformer.trans import *
91
+ from .flux_ip_adapter import *
92
+
93
+ if is_torch_xla_available():
94
+ import torch_xla.core.xla_model as xm
95
+
96
+ XLA_AVAILABLE = True
97
+ else:
98
+ XLA_AVAILABLE = False
99
+
100
+
101
+ @dataclass
102
+ class FluxPipelineOutput(BaseOutput):
103
+ """
104
+ Output class for Flux pipelines.
105
+ Args:
106
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
107
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
108
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
109
+ """
110
+
111
+ images: Union[List[PIL.Image.Image], np.ndarray]
112
+
113
+
114
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
115
+
116
+
117
+ EXAMPLE_DOC_STRING = """
118
+ Examples:
119
+ ```py
120
+ >>> import torch
121
+ >>> from diffusers import FluxPipeline
122
+ >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
123
+ >>> pipe.to("cuda")
124
+ >>> prompt = "A cat holding a sign that says hello world"
125
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
126
+ >>> # Refer to the pipeline documentation for more details.
127
+ >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
128
+ >>> image.save("flux.png")
129
+ ```
130
+ """
131
+
132
+
133
+ def calculate_shift(
134
+ image_seq_len,
135
+ base_seq_len: int = 256,
136
+ max_seq_len: int = 4096,
137
+ base_shift: float = 0.5,
138
+ max_shift: float = 1.16,
139
+ ):
140
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
141
+ b = base_shift - m * base_seq_len
142
+ mu = image_seq_len * m + b
143
+ return mu
144
+
145
+
146
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
147
+ def retrieve_timesteps(
148
+ scheduler,
149
+ num_inference_steps: Optional[int] = None,
150
+ device: Optional[Union[str, torch.device]] = None,
151
+ timesteps: Optional[List[int]] = None,
152
+ sigmas: Optional[List[float]] = None,
153
+ **kwargs,
154
+ ):
155
+ """
156
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
157
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
158
+ Args:
159
+ scheduler (`SchedulerMixin`):
160
+ The scheduler to get timesteps from.
161
+ num_inference_steps (`int`):
162
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
163
+ must be `None`.
164
+ device (`str` or `torch.device`, *optional*):
165
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
166
+ timesteps (`List[int]`, *optional*):
167
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
168
+ `num_inference_steps` and `sigmas` must be `None`.
169
+ sigmas (`List[float]`, *optional*):
170
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
171
+ `num_inference_steps` and `timesteps` must be `None`.
172
+ Returns:
173
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
174
+ second element is the number of inference steps.
175
+ """
176
+ if timesteps is not None and sigmas is not None:
177
+ raise ValueError(
178
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
179
+ )
180
+ if timesteps is not None:
181
+ accepts_timesteps = "timesteps" in set(
182
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
183
+ )
184
+ if not accepts_timesteps:
185
+ raise ValueError(
186
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
187
+ f" timestep schedules. Please check whether you are using the correct scheduler."
188
+ )
189
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
190
+ timesteps = scheduler.timesteps
191
+ num_inference_steps = len(timesteps)
192
+ elif sigmas is not None:
193
+ accept_sigmas = "sigmas" in set(
194
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
195
+ )
196
+ if not accept_sigmas:
197
+ raise ValueError(
198
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
199
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
200
+ )
201
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
202
+ timesteps = scheduler.timesteps
203
+ num_inference_steps = len(timesteps)
204
+ else:
205
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
206
+ timesteps = scheduler.timesteps
207
+ return timesteps, num_inference_steps
208
+
209
+
210
+ class LibreFluxIpAdapterPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
211
+ r"""
212
+ The Flux pipeline for text-to-image generation.
213
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
214
+ Args:
215
+ transformer ([`LibreFluxTransformer2DModel`]):
216
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
217
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
218
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
219
+ vae ([`AutoencoderKL`]):
220
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
221
+ text_encoder ([`CLIPTextModelWithProjection`]):
222
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
223
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
224
+ with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
225
+ as its dimension.
226
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
227
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
228
+ specifically the
229
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
230
+ variant.
231
+ tokenizer (`CLIPTokenizer`):
232
+ Tokenizer of class
233
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
234
+ tokenizer_2 (`CLIPTokenizer`):
235
+ Second Tokenizer of class
236
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
237
+ """
238
+
239
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
240
+ _optional_components = []
241
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
242
+
243
+ def __init__(
244
+ self,
245
+ scheduler: FlowMatchEulerDiscreteScheduler,
246
+ vae: AutoencoderKL,
247
+ text_encoder: CLIPTextModel,
248
+ tokenizer: CLIPTokenizer,
249
+ text_encoder_2: T5EncoderModel,
250
+ tokenizer_2: T5TokenizerFast,
251
+ transformer: LibreFluxTransformer2DModel,
252
+ image_encoder: CLIPVisionModelWithProjection,
253
+ ):
254
+ super().__init__()
255
+
256
+ self.ip_adapter = None
257
+
258
+ self.register_modules(
259
+ vae=vae,
260
+ text_encoder=text_encoder,
261
+ text_encoder_2=text_encoder_2,
262
+ tokenizer=tokenizer,
263
+ tokenizer_2=tokenizer_2,
264
+ transformer=transformer,
265
+ scheduler=scheduler,
266
+ image_encoder=image_encoder,
267
+
268
+ )
269
+ self.vae_scale_factor = (
270
+ 2 ** (len(self.vae.config.block_out_channels))
271
+ if hasattr(self, "vae") and self.vae is not None
272
+ else 16
273
+ )
274
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
275
+ self.tokenizer_max_length = (
276
+ self.tokenizer.model_max_length
277
+ if hasattr(self, "tokenizer") and self.tokenizer is not None
278
+ else 77
279
+ )
280
+ self.default_sample_size = 64
281
+
282
+ #self.clip_image_processor = CLIPImageProcessor()
283
+ from transformers import AutoProcessor, SiglipVisionModel
284
+ self.clip_image_processor = AutoProcessor.from_pretrained("google/siglip-so400m-patch14-384")
285
+
286
+ def _get_t5_prompt_embeds(
287
+ self,
288
+ prompt: Union[str, List[str]] = None,
289
+ num_images_per_prompt: int = 1,
290
+ max_sequence_length: int = 512,
291
+ device: Optional[torch.device] = None,
292
+ dtype: Optional[torch.dtype] = None,
293
+ ):
294
+ device = device or self._execution_device
295
+ dtype = dtype or self.text_encoder.dtype
296
+
297
+ prompt = [prompt] if isinstance(prompt, str) else prompt
298
+ batch_size = len(prompt)
299
+
300
+ text_inputs = self.tokenizer_2(
301
+ prompt,
302
+ padding="max_length",
303
+ max_length=max_sequence_length,
304
+ truncation=True,
305
+ return_length=False,
306
+ return_overflowing_tokens=False,
307
+ return_tensors="pt",
308
+ )
309
+ prompt_attention_mask = text_inputs.attention_mask
310
+ text_input_ids = text_inputs.input_ids
311
+ untruncated_ids = self.tokenizer_2(
312
+ prompt, padding="longest", return_tensors="pt"
313
+ ).input_ids
314
+
315
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
316
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
317
+ logger.warning(
318
+ "The following part of your input was truncated because `max_sequence_length` is set to "
319
+ f" {max_sequence_length} tokens: {removed_text}"
320
+ )
321
+
322
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
323
+
324
+ dtype = self.text_encoder_2.dtype
325
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
326
+
327
+ _, seq_len, _ = prompt_embeds.shape
328
+
329
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
330
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
331
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
332
+
333
+ return prompt_embeds, prompt_attention_mask
334
+
335
+ def _get_clip_prompt_embeds(
336
+ self,
337
+ prompt: Union[str, List[str]],
338
+ num_images_per_prompt: int = 1,
339
+ device: Optional[torch.device] = None,
340
+ ):
341
+ device = device or self._execution_device
342
+
343
+ prompt = [prompt] if isinstance(prompt, str) else prompt
344
+ batch_size = len(prompt)
345
+
346
+ text_inputs = self.tokenizer(
347
+ prompt,
348
+ padding="max_length",
349
+ max_length=self.tokenizer_max_length,
350
+ truncation=True,
351
+ return_overflowing_tokens=False,
352
+ return_length=False,
353
+ return_tensors="pt",
354
+ )
355
+
356
+ text_input_ids = text_inputs.input_ids
357
+ untruncated_ids = self.tokenizer(
358
+ prompt, padding="longest", return_tensors="pt"
359
+ ).input_ids
360
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
361
+ text_input_ids, untruncated_ids
362
+ ):
363
+ removed_text = self.tokenizer.batch_decode(
364
+ untruncated_ids[:, self.tokenizer_max_length - 1 : -1]
365
+ )
366
+ logger.warning(
367
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
368
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
369
+ )
370
+ prompt_embeds = self.text_encoder(
371
+ text_input_ids.to(device), output_hidden_states=False
372
+ )
373
+
374
+ # Use pooled output of CLIPTextModel
375
+ prompt_embeds = prompt_embeds.pooler_output
376
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
377
+
378
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
379
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
380
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
381
+
382
+ return prompt_embeds
383
+
384
+ @lru_cache(maxsize=128)
385
+ def encode_prompt(
386
+ self,
387
+ prompt: Union[str, List[str]],
388
+ prompt_2: Union[str, List[str]],
389
+ device: Optional[torch.device] = None,
390
+ num_images_per_prompt: int = 1,
391
+ prompt_embeds: Optional[torch.FloatTensor] = None,
392
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
393
+ max_sequence_length: int = 512,
394
+ lora_scale: Optional[float] = None,
395
+ ):
396
+ r"""
397
+ Args:
398
+ prompt (`str` or `List[str]`, *optional*):
399
+ prompt to be encoded
400
+ prompt_2 (`str` or `List[str]`, *optional*):
401
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
402
+ used in all text-encoders
403
+ device: (`torch.device`):
404
+ torch device
405
+ num_images_per_prompt (`int`):
406
+ number of images that should be generated per prompt
407
+ prompt_embeds (`torch.FloatTensor`, *optional*):
408
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
409
+ provided, text embeddings will be generated from `prompt` input argument.
410
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
411
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
412
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
413
+ clip_skip (`int`, *optional*):
414
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
415
+ the output of the pre-final layer will be used for computing the prompt embeddings.
416
+ lora_scale (`float`, *optional*):
417
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
418
+ """
419
+ device = device or self._execution_device
420
+
421
+ # set lora scale so that monkey patched LoRA
422
+ # function of text encoder can correctly access it
423
+ if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
424
+ self._lora_scale = lora_scale
425
+
426
+ # dynamically adjust the LoRA scale
427
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
428
+ scale_lora_layers(self.text_encoder, lora_scale)
429
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
430
+ scale_lora_layers(self.text_encoder_2, lora_scale)
431
+
432
+ prompt = [prompt] if isinstance(prompt, str) else prompt
433
+ if prompt is not None:
434
+ batch_size = len(prompt)
435
+ else:
436
+ batch_size = prompt_embeds.shape[0]
437
+
438
+ prompt_attention_mask = None
439
+ if prompt_embeds is None:
440
+ prompt_2 = prompt_2 or prompt
441
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
442
+
443
+ # We only use the pooled prompt output from the CLIPTextModel
444
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
445
+ prompt=prompt,
446
+ device=device,
447
+ num_images_per_prompt=num_images_per_prompt,
448
+ )
449
+ prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
450
+ prompt=prompt_2,
451
+ num_images_per_prompt=num_images_per_prompt,
452
+ max_sequence_length=max_sequence_length,
453
+ device=device,
454
+ )
455
+
456
+ if self.text_encoder is not None:
457
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
458
+ # Retrieve the original scale by scaling back the LoRA layers
459
+ unscale_lora_layers(self.text_encoder, lora_scale)
460
+
461
+ if self.text_encoder_2 is not None:
462
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
463
+ # Retrieve the original scale by scaling back the LoRA layers
464
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
465
+
466
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
467
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
468
+ text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
469
+
470
+ return prompt_embeds, pooled_prompt_embeds, text_ids, prompt_attention_mask
471
+
472
+ def check_inputs(
473
+ self,
474
+ prompt,
475
+ prompt_2,
476
+ height,
477
+ width,
478
+ prompt_embeds=None,
479
+ pooled_prompt_embeds=None,
480
+ callback_on_step_end_tensor_inputs=None,
481
+ max_sequence_length=None,
482
+ ):
483
+ if height % 8 != 0 or width % 8 != 0:
484
+ raise ValueError(
485
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
486
+ )
487
+
488
+ if callback_on_step_end_tensor_inputs is not None and not all(
489
+ k in self._callback_tensor_inputs
490
+ for k in callback_on_step_end_tensor_inputs
491
+ ):
492
+ raise ValueError(
493
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
494
+ )
495
+
496
+ if prompt is not None and prompt_embeds is not None:
497
+ raise ValueError(
498
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
499
+ " only forward one of the two."
500
+ )
501
+ elif prompt_2 is not None and prompt_embeds is not None:
502
+ raise ValueError(
503
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
504
+ " only forward one of the two."
505
+ )
506
+ elif prompt is None and prompt_embeds is None:
507
+ raise ValueError(
508
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
509
+ )
510
+ elif prompt is not None and (
511
+ not isinstance(prompt, str) and not isinstance(prompt, list)
512
+ ):
513
+ raise ValueError(
514
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
515
+ )
516
+ elif prompt_2 is not None and (
517
+ not isinstance(prompt_2, str) and not isinstance(prompt_2, list)
518
+ ):
519
+ raise ValueError(
520
+ f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}"
521
+ )
522
+
523
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
524
+ raise ValueError(
525
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
526
+ )
527
+
528
+ if max_sequence_length is not None and max_sequence_length > 512:
529
+ raise ValueError(
530
+ f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}"
531
+ )
532
+
533
+ @staticmethod
534
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
535
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
536
+ latent_image_ids[..., 1] = (
537
+ latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
538
+ )
539
+ latent_image_ids[..., 2] = (
540
+ latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
541
+ )
542
+
543
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = (
544
+ latent_image_ids.shape
545
+ )
546
+
547
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
548
+ latent_image_ids = latent_image_ids.reshape(
549
+ batch_size,
550
+ latent_image_id_height * latent_image_id_width,
551
+ latent_image_id_channels,
552
+ )
553
+
554
+ return latent_image_ids.to(dtype=dtype, device=device)
555
+
556
+ @staticmethod
557
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
558
+ latents = latents.view(
559
+ batch_size, num_channels_latents, height // 2, 2, width // 2, 2
560
+ )
561
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
562
+ latents = latents.reshape(
563
+ batch_size, (height // 2) * (width // 2), num_channels_latents * 4
564
+ )
565
+
566
+ return latents
567
+
568
+ @staticmethod
569
+ def _unpack_latents(latents, height, width, vae_scale_factor):
570
+ batch_size, num_patches, channels = latents.shape
571
+
572
+ height = height // vae_scale_factor
573
+ width = width // vae_scale_factor
574
+
575
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
576
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
577
+
578
+ latents = latents.reshape(
579
+ batch_size, channels // (2 * 2), height * 2, width * 2
580
+ )
581
+
582
+ return latents
583
+
584
+ def prepare_latents(
585
+ self,
586
+ batch_size,
587
+ num_channels_latents,
588
+ height,
589
+ width,
590
+ dtype,
591
+ device,
592
+ generator,
593
+ latents=None,
594
+ ):
595
+ height = 2 * (int(height) // self.vae_scale_factor)
596
+ width = 2 * (int(width) // self.vae_scale_factor)
597
+
598
+ shape = (batch_size, num_channels_latents, height, width)
599
+
600
+ if latents is not None:
601
+ latent_image_ids = self._prepare_latent_image_ids(
602
+ batch_size, height, width, device, dtype
603
+ )
604
+ return latents, latent_image_ids
605
+
606
+ if isinstance(generator, list) and len(generator) != batch_size:
607
+ raise ValueError(
608
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
609
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
610
+ )
611
+
612
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
613
+ latents = self._pack_latents(
614
+ latents, batch_size, num_channels_latents, height, width
615
+ )
616
+
617
+ latent_image_ids = self._prepare_latent_image_ids(
618
+ batch_size, height, width, device, dtype
619
+ )
620
+
621
+ return latents, latent_image_ids
622
+
623
+ @property
624
+ def guidance_scale(self):
625
+ return self._guidance_scale
626
+
627
+ @property
628
+ def joint_attention_kwargs(self):
629
+ return self._joint_attention_kwargs
630
+
631
+ @property
632
+ def num_timesteps(self):
633
+ return self._num_timesteps
634
+
635
+ @property
636
+ def interrupt(self):
637
+ return self._interrupt
638
+
639
+ @torch.no_grad()
640
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
641
+ def __call__(
642
+ self,
643
+ prompt: Union[str, List[str]] = None,
644
+ prompt_mask: Optional[Union[torch.FloatTensor, List[torch.FloatTensor]]] = None,
645
+ negative_mask: Optional[
646
+ Union[torch.FloatTensor, List[torch.FloatTensor]]
647
+ ] = None,
648
+ prompt_2: Optional[Union[str, List[str]]] = None,
649
+ height: Optional[int] = None,
650
+ width: Optional[int] = None,
651
+ num_inference_steps: int = 28,
652
+ timesteps: List[int] = None,
653
+ guidance_scale: float = 3.5,
654
+ num_images_per_prompt: Optional[int] = 1,
655
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
656
+ latents: Optional[torch.FloatTensor] = None,
657
+ prompt_embeds: Optional[torch.FloatTensor] = None,
658
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
659
+ output_type: Optional[str] = "pil",
660
+ return_dict: bool = True,
661
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
662
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
663
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
664
+ max_sequence_length: int = 512,
665
+ guidance_scale_real: float = 1.0,
666
+ negative_prompt: Union[str, List[str]] = "",
667
+ negative_prompt_2: Union[str, List[str]] = "",
668
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
669
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
670
+ no_cfg_until_timestep: int = 0,
671
+ do_batch_cfg: bool=True,
672
+ ip_adapter_image: Image=None,
673
+ ip_adapter_scale: float=1.0,
674
+ device=torch.device('cuda'), # TODO: support non-CUDA devices; passing None falls back to `self._execution_device`
675
+ ):
676
+ r"""
677
+ Function invoked when calling the pipeline for generation.
678
+ Args:
679
+ prompt (`str` or `List[str]`, *optional*):
680
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
681
+ instead.
682
+ prompt_mask (`torch.FloatTensor` or `List[torch.FloatTensor]`, *optional*):
683
+ The attention mask for the T5 prompt embeddings. If not provided, it is computed when `prompt` is
684
+ encoded.
685
+ prompt_2 (`str` or `List[str]`, *optional*):
686
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
687
+ will be used instead
688
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
689
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
690
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
691
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
692
+ num_inference_steps (`int`, *optional*, defaults to 50):
693
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
694
+ expense of slower inference.
695
+ timesteps (`List[int]`, *optional*):
696
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
697
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
698
+ passed will be used. Must be in descending order.
699
+ guidance_scale (`float`, *optional*, defaults to 3.5):
700
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
701
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
702
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
703
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
704
+ usually at the expense of lower image quality.
705
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
706
+ The number of images to generate per prompt.
707
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
708
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
709
+ to make generation deterministic.
710
+ latents (`torch.FloatTensor`, *optional*):
711
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
712
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
713
+ tensor will be generated by sampling using the supplied random `generator`.
714
+ prompt_embeds (`torch.FloatTensor`, *optional*):
715
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
716
+ provided, text embeddings will be generated from `prompt` input argument.
717
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
718
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
719
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
720
+ output_type (`str`, *optional*, defaults to `"pil"`):
721
+ The output format of the generated image. Choose between
722
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
723
+ return_dict (`bool`, *optional*, defaults to `True`):
724
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
725
+ joint_attention_kwargs (`dict`, *optional*):
726
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
727
+ `self.processor` in
728
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
729
+ callback_on_step_end (`Callable`, *optional*):
730
+ A function that is called at the end of each denoising step during inference. The function is called
731
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
732
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
733
+ `callback_on_step_end_tensor_inputs`.
734
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
735
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
736
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
737
+ `._callback_tensor_inputs` attribute of your pipeline class.
738
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
739
+ Examples:
740
+ Returns:
741
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
742
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
743
+ images.
744
+ """
745
+
746
+ height = height or self.default_sample_size * self.vae_scale_factor
747
+ width = width or self.default_sample_size * self.vae_scale_factor
748
+
749
+ # 1. Check inputs. Raise error if not correct
750
+ self.check_inputs(
751
+ prompt,
752
+ prompt_2,
753
+ height,
754
+ width,
755
+ prompt_embeds=prompt_embeds,
756
+ pooled_prompt_embeds=pooled_prompt_embeds,
757
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
758
+ max_sequence_length=max_sequence_length,
759
+ )
760
+
761
+ # `guidance_scale_real` is kept for backwards compatibility; by default it is
762
+ # simply synced to `guidance_scale` below so that the real-CFG path uses the
763
+ # same value.
764
+ guidance_scale_real = guidance_scale
765
+
766
+ self._guidance_scale = guidance_scale
767
+ self._guidance_scale_real = guidance_scale_real
768
+ self._joint_attention_kwargs = joint_attention_kwargs
769
+ self._interrupt = False
770
+
771
+ # 2. Define call parameters
772
+ if prompt is not None and isinstance(prompt, str):
773
+ batch_size = 1
774
+ elif prompt is not None and isinstance(prompt, list):
775
+ batch_size = len(prompt)
776
+ else:
777
+ batch_size = prompt_embeds.shape[0]
778
+
779
+ device = device or self._execution_device
780
+
781
+ lora_scale = (
782
+ self.joint_attention_kwargs.get("scale", None)
783
+ if self.joint_attention_kwargs is not None
784
+ else None
785
+ )
786
+ (
787
+ prompt_embeds,
788
+ pooled_prompt_embeds,
789
+ text_ids,
790
+ _prompt_mask,
791
+ ) = self.encode_prompt(
792
+ prompt=prompt,
793
+ prompt_2=prompt_2,
794
+ prompt_embeds=prompt_embeds,
795
+ pooled_prompt_embeds=pooled_prompt_embeds,
796
+ device=device,
797
+ num_images_per_prompt=num_images_per_prompt,
798
+ max_sequence_length=max_sequence_length,
799
+ lora_scale=lora_scale,
800
+ )
801
+ if _prompt_mask is not None:
802
+ prompt_mask = _prompt_mask
803
+ assert prompt_mask is not None
804
+
805
+ if negative_prompt_2 == "" and negative_prompt != "":
806
+ negative_prompt_2 = negative_prompt
807
+
808
+ negative_text_ids = text_ids
809
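+ # Real CFG needs negative embeddings: encode the negative prompt only when they were not passed in.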
+ if self._guidance_scale_real > 1.0 and (
810
+ negative_prompt_embeds is None or negative_pooled_prompt_embeds is None
811
+ ):
812
+ (
813
+ negative_prompt_embeds,
814
+ negative_pooled_prompt_embeds,
815
+ negative_text_ids,
816
+ _neg_prompt_mask,
817
+ ) = self.encode_prompt(
818
+ prompt=negative_prompt,
819
+ prompt_2=negative_prompt_2,
820
+ prompt_embeds=None,
821
+ pooled_prompt_embeds=None,
822
+ device=device,
823
+ num_images_per_prompt=num_images_per_prompt,
824
+ max_sequence_length=max_sequence_length,
825
+ lora_scale=lora_scale,
826
+ )
827
+
828
+ if _neg_prompt_mask is not None:
829
+ negative_mask = _neg_prompt_mask
830
+
831
+ assert negative_mask is not None
832
+
833
+ # 4. Prepare latent variables
834
+ num_channels_latents = self.transformer.config.in_channels // 4
835
+ latents, latent_image_ids = self.prepare_latents(
836
+ batch_size * num_images_per_prompt,
837
+ num_channels_latents,
838
+ height,
839
+ width,
840
+ prompt_embeds.dtype,
841
+ device,
842
+ generator,
843
+ latents,
844
+ )
845
+
846
+ # 5. Prepare timesteps
847
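+ # calculate_shift derives a resolution-dependent shift (mu) from the image token count; it is passed to the scheduler to adapt the sigma schedule.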
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
848
+ image_seq_len = latents.shape[1]
849
+ mu = calculate_shift(
850
+ image_seq_len,
851
+ self.scheduler.config.base_image_seq_len,
852
+ self.scheduler.config.max_image_seq_len,
853
+ self.scheduler.config.base_shift,
854
+ self.scheduler.config.max_shift,
855
+ )
856
+ timesteps, num_inference_steps = retrieve_timesteps(
857
+ self.scheduler,
858
+ num_inference_steps,
859
+ device,
860
+ timesteps,
861
+ sigmas,
862
+ mu=mu,
863
+ )
864
+ num_warmup_steps = max(
865
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
866
+ )
867
+ self._num_timesteps = len(timesteps)
868
+
869
872
+ text_ids = text_ids.to(device=device)
873
+
874
+ # handle guidance
875
+ if self.transformer.config.guidance_embeds:
876
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
877
+ guidance = guidance.expand(latents.shape[0])
878
+ else:
879
+ guidance = None
880
+
881
+ # Denoising loop
882
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
883
+ for i, t in enumerate(timesteps):
884
+ if self.interrupt:
885
+ continue
886
+
887
+ # Prepare the latent model input
888
+ prompt_embeds_input = prompt_embeds
889
+ pooled_prompt_embeds_input = pooled_prompt_embeds
890
+ text_ids_input = text_ids
891
+ latent_image_ids_input = latent_image_ids
892
+ prompt_mask_input = prompt_mask
893
+ latent_model_input = latents
894
+
895
+ if guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
896
+ progress_bar.set_postfix(
897
+ {
898
+ 'ts': t.detach().item() / 1000.0,
899
+ 'cfg': self._guidance_scale_real,
900
+ },
901
+ )
902
+ else:
903
+ progress_bar.set_postfix(
904
+ {
905
+ 'ts': t.detach().item() / 1000.0,
906
+ 'cfg': 'N/A',
907
+ },
908
+ )
909
+
910
+ # Forward pass through the transformer
911
+ with torch.no_grad():
912
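+ # Encode the IP-Adapter reference image with the image encoder. Note this is recomputed on every denoising step; it could be hoisted out of the loop since the image does not change.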
+ if ip_adapter_image is not None:
913
+
914
+ clip_image = self.clip_image_processor(images=ip_adapter_image,
915
+ return_tensors="pt").pixel_values
916
+ clip_image = clip_image.to(device=self.image_encoder.device,
917
+ dtype=self.image_encoder.dtype)
918
+ image_embeds = self.image_encoder(clip_image).pooler_output
919
+ image_embeds_input = image_embeds
920
+ else:
921
+ image_embeds = None
922
+ image_embeds_input = None
923
+
924
+ layer_scale = torch.Tensor([ip_adapter_scale])
925
+ layer_scale_input = layer_scale
926
+ neg_layer_scale = torch.Tensor([0.0])
927
+
928
+
929
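+ # Batched CFG: stack the negative and positive conditioning so both branches run in a single transformer forward pass.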
+ if do_batch_cfg and guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
930
+ # Concatenate prompt embeddings
931
+ prompt_embeds_input = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
932
+ pooled_prompt_embeds_input = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
933
+
934
+ image_embeds_input = None
935
+ if image_embeds is not None:
936
+ image_embeds_input = torch.cat([image_embeds]*2, dim=0)
937
+
938
+ layer_scale_input = torch.cat([neg_layer_scale, layer_scale], dim=0)
939
+ # Concatenate text IDs if they are used
940
+ # if text_ids is not None and negative_text_ids is not None:
941
+ # text_ids_input = torch.cat([negative_text_ids, text_ids], dim=0)
942
+
943
+ # Concatenate latent image IDs if they are used
944
+ # if latent_image_ids is not None:
945
+ # latent_image_ids_input = torch.cat([latent_image_ids, latent_image_ids], dim=0)
946
+
947
+ # Concatenate prompt masks if they are used
948
+ if prompt_mask is not None and negative_mask is not None:
949
+ prompt_mask_input = torch.cat([negative_mask, prompt_mask], dim=0)
950
+ # Duplicate latents for unconditional and conditional inputs
951
+ latent_model_input = torch.cat([latents] * 2)
952
+
953
+ # Expand timestep to match batch size
954
+ timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
955
+
956
+ guidance = None
957
+
958
+ div_timestep = (timestep / 1000.0)
959
+ text_ids = [ t for t in text_ids ]
960
+
961
+ if self.ip_adapter is None:
962
+ noise_pred = self.transformer(
963
+ latent_model_input,
964
+ timestep=div_timestep.to(device=self.transformer.device),
965
+ guidance=guidance,
966
+ pooled_projections=pooled_prompt_embeds_input.to(device=self.transformer.device),
967
+ encoder_hidden_states=prompt_embeds_input.to(device=self.transformer.device),
968
+ attention_mask=prompt_mask_input.to(device=self.transformer.device),
969
+ txt_ids=text_ids_input[0],
970
+ img_ids=latent_image_ids_input[0].to(device=self.transformer.device),
971
+ return_dict=False,
972
+ )[0]
973
+ else:
974
+ noise_pred = self.ip_adapter(
975
+ image_embeds_input,
976
+ latent_model_input.to(device=self.transformer.device),
977
+ layer_scale=layer_scale_input,
978
+ timestep=div_timestep.to(device=self.transformer.device),
979
+ guidance=guidance,
980
+ pooled_projections=pooled_prompt_embeds_input.to(device=self.transformer.device),
981
+ encoder_hidden_states=prompt_embeds_input.to(device=self.transformer.device),
982
+ attention_mask=prompt_mask_input.to(device=self.transformer.device),
983
+ txt_ids=text_ids_input[0],
984
+ img_ids=latent_image_ids_input[0].to(device=self.transformer.device),
985
+ return_dict=False,
986
+ )[0]
987
+
988
+ # Apply real CFG
989
+ if guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
990
+ if do_batch_cfg:
991
+ # Batched CFG: Split the noise prediction into unconditional and conditional parts
992
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
993
+ noise_pred = noise_pred_uncond + guidance_scale_real * (noise_pred_cond - noise_pred_uncond)
994
+ else:
995
+ # Sequential CFG: Compute unconditional noise prediction separately
996
+ if self.ip_adapter is None:
997
+ noise_pred_uncond = self.transformer(
998
+ latents.to(device=self.transformer.device),
999
+ timestep=div_timestep,
1000
+ guidance=guidance,
1001
+ pooled_projections=negative_pooled_prompt_embeds.to(device=self.transformer.device),
1002
+ encoder_hidden_states=negative_prompt_embeds.to(device=self.transformer.device),
1003
+ attention_mask=negative_mask,
1004
+ txt_ids=negative_text_ids.to(device=self.transformer.device) if negative_text_ids is not None else None,
1005
+ img_ids=latent_image_ids[0].to(device=self.transformer.device),
1006
+ return_dict=False,
1007
+ )[0]
1008
+ else:
1009
+ noise_pred_uncond = self.ip_adapter(
1010
+ image_embeds,
1011
+ latents.to(device=self.transformer.device),
1012
+ layer_scale=neg_layer_scale,
1013
+ timestep=div_timestep,
1014
+ guidance=guidance,
1015
+ pooled_projections=negative_pooled_prompt_embeds.to(device=self.transformer.device),
1016
+ encoder_hidden_states=negative_prompt_embeds.to(device=self.transformer.device),
1017
+ attention_mask=negative_mask,
1018
+ txt_ids=negative_text_ids.to(device=self.transformer.device) if negative_text_ids is not None else None,
1019
+ img_ids=latent_image_ids[0].to(device=self.transformer.device),
1020
+ return_dict=False,
1021
+ )[0]
1022
+
1023
+ # Combine conditional and unconditional predictions
1024
+ noise_pred = noise_pred_uncond + guidance_scale_real * (noise_pred - noise_pred_uncond)
1025
+
1026
+ # Compute the previous noisy sample x_t -> x_t-1
1027
+ latents_dtype = latents.dtype
1028
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1029
+
1030
+ # Ensure latents have the correct dtype
1031
+ if latents.dtype != latents_dtype:
1032
+ if torch.backends.mps.is_available():
1033
+ latents = latents.to(latents_dtype)
1034
+
1035
+ # Callback at the end of the step, if provided
1036
+ if callback_on_step_end is not None:
1037
+ callback_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs}
1038
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1039
+ latents = callback_outputs.get("latents", latents)
1040
+ prompt_embeds = callback_outputs.get("prompt_embeds", prompt_embeds)
1041
+
1042
+ # Update the progress bar
1043
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1044
+ progress_bar.update()
1045
+
1046
+ # Mark step for XLA devices
1047
+ if XLA_AVAILABLE:
1048
+ xm.mark_step()
1049
+
1050
+ if output_type == "latent":
1051
+ image = latents
1052
+
1053
+ else:
1054
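+ # Unpack the latent token sequence back into a 2D latent grid, undo the VAE scaling/shift normalization, then decode to pixels.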
+ latents = self._unpack_latents(
1055
+ latents, height, width, self.vae_scale_factor
1056
+ )
1057
+ latents = (
1058
+ latents / self.vae.config.scaling_factor
1059
+ ) + self.vae.config.shift_factor
1060
+
1061
+ latents = latents.to(dtype=self.vae.dtype)
1062
+
1063
+ image = self.vae.decode(
1064
+ latents,
1065
+ return_dict=False,
1066
+ )[0]
1067
+ image = self.image_processor.postprocess(image, output_type=output_type)
1068
+
1069
+ # Offload all models
1070
+ self.maybe_free_model_hooks()
1071
+
1072
+ if not return_dict:
1073
+ return (image,)
1074
+
1075
+ return FluxPipelineOutput(images=image)
1076
+
1077
+ def to(self, *args, **kwargs):
1078
+ """
1079
+ Overrides the default .to() method to also move the
1080
+ unregistered ip_adapter.
1081
+ """
1082
+ super().to(*args, **kwargs)
1083
+
1084
+ # Manually move the unregistered ip_adapter as well;
1085
+ # it receives the same args (e.g. device, dtype).
1086
+ if self.ip_adapter is not None:
1087
+ self.ip_adapter.to(*args, **kwargs)
1088
+
1089
+ # Return `self` to allow for chaining (e.g., pipe.to(device).half())
1090
+ return self
1091
+
1092
+ def load_ip_adapter(self, checkpoint_path):
1093
+ """ Init model and load weights, or just load weights"""
1094
+
1095
+ if self.ip_adapter is None:
1096
+ image_proj_model = ImageProjModel(clip_dim=self.image_encoder.config.hidden_size,
1097
+ cross_attention_dim=4096,
1098
+ num_tokens=128)
1099
+
1100
+
1101
+ self.ip_adapter = LibreFluxIPAdapter(self.transformer,
1102
+ image_proj_model)
1103
+
1104
+ self.ip_adapter.load_from_checkpoint(checkpoint_path)
1105
+ self.ip_adapter.to(self.transformer.device, dtype=self.dtype)
1106
+
1107
+ return self
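
Taken together, the pieces above suggest the following minimal usage sketch. The repository id, the `custom_pipeline` wiring, and all parameter values are assumptions for illustration, not documented defaults; only `load_ip_adapter`, `ip_adapter_image`, and `ip_adapter_scale` come from the code in this commit.

```python
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import load_image

# Hypothetical repo id / custom pipeline module name -- substitute the actual ones.
pipe = DiffusionPipeline.from_pretrained(
    "your-namespace/LibreFlux-IPAdapter",   # assumption: repository containing this pipeline
    custom_pipeline="pipeline",              # assumption: module exposing LibreFluxIPAdapterPipeline
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")
pipe.load_ip_adapter("ip_adapter.pt")        # weights file uploaded in this commit

reference = load_image("examples/mona.jpg")  # example image tracked via LFS above

image = pipe(
    prompt="an oil painting of a woman in the style of the reference image",
    ip_adapter_image=reference,
    ip_adapter_scale=0.8,                    # assumption: moderate adapter strength
    guidance_scale=4.0,                      # > 1.0 enables the real-CFG path
    num_inference_steps=28,
    height=1024,
    width=1024,
).images[0]
image.save("libre_flux_ip_adapter_out.png")
```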
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "_class_name": "FlowMatchEulerDiscreteScheduler",
3
+ "_diffusers_version": "0.35.2",
4
+ "base_image_seq_len": 256,
5
+ "base_shift": 0.5,
6
+ "max_image_seq_len": 4096,
7
+ "max_shift": 1.15,
8
+ "num_train_timesteps": 1000,
9
+ "shift": 1.0,
10
+ "use_dynamic_shifting": false
11
+ }
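
The `base_image_seq_len`/`max_image_seq_len` and `base_shift`/`max_shift` entries above are the values the pipeline feeds into `calculate_shift` when preparing timesteps. A minimal sketch of that interpolation (mirroring the diffusers helper, shown here only for illustration):

```python
def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096,
                    base_shift=0.5, max_shift=1.15):
    # Linear interpolation of the schedule shift (mu) in the image token count.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

# A 1024x1024 image packs to (1024 / 16) ** 2 = 4096 latent tokens, so mu reaches max_shift.
print(calculate_shift(4096))   # 1.15
print(calculate_shift(256))    # 0.5
```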
text_encoder/config.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "_name_or_path": "openai/clip-vit-large-patch14",
3
+ "architectures": [
4
+ "CLIPTextModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 0,
8
+ "dropout": 0.0,
9
+ "eos_token_id": 2,
10
+ "hidden_act": "quick_gelu",
11
+ "hidden_size": 768,
12
+ "initializer_factor": 1.0,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 3072,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 77,
17
+ "model_type": "clip_text_model",
18
+ "num_attention_heads": 12,
19
+ "num_hidden_layers": 12,
20
+ "pad_token_id": 1,
21
+ "projection_dim": 768,
22
+ "torch_dtype": "bfloat16",
23
+ "transformers_version": "4.43.3",
24
+ "vocab_size": 49408
25
+ }
text_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:893d67a23f4693ed42cdab4cbad7fe3e727cf59609c40da28a46b5470f9ed082
3
+ size 246144352
text_encoder_2/config.json ADDED
@@ -0,0 +1,32 @@
1
+ {
2
+ "_name_or_path": "google/t5-v1_1-xxl",
3
+ "architectures": [
4
+ "T5EncoderModel"
5
+ ],
6
+ "classifier_dropout": 0.0,
7
+ "d_ff": 10240,
8
+ "d_kv": 64,
9
+ "d_model": 4096,
10
+ "decoder_start_token_id": 0,
11
+ "dense_act_fn": "gelu_new",
12
+ "dropout_rate": 0.1,
13
+ "eos_token_id": 1,
14
+ "feed_forward_proj": "gated-gelu",
15
+ "initializer_factor": 1.0,
16
+ "is_encoder_decoder": true,
17
+ "is_gated_act": true,
18
+ "layer_norm_epsilon": 1e-06,
19
+ "model_type": "t5",
20
+ "num_decoder_layers": 24,
21
+ "num_heads": 64,
22
+ "num_layers": 24,
23
+ "output_past": true,
24
+ "pad_token_id": 0,
25
+ "relative_attention_max_distance": 128,
26
+ "relative_attention_num_buckets": 32,
27
+ "tie_word_embeddings": false,
28
+ "torch_dtype": "bfloat16",
29
+ "transformers_version": "4.43.3",
30
+ "use_cache": true,
31
+ "vocab_size": 32128
32
+ }
text_encoder_2/model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec87bffd1923e8b2774a6d240c922a41f6143081d52cf83b8fe39e9d838c893e
3
+ size 4994582224
text_encoder_2/model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a5640855b301fcdbceddfa90ae8066cd9414aff020552a201a255ecf2059da00
3
+ size 4530066360
text_encoder_2/model.safetensors.index.json ADDED
@@ -0,0 +1,226 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 9524621312
4
+ },
5
+ "weight_map": {
6
+ "encoder.block.0.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
7
+ "encoder.block.0.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
8
+ "encoder.block.0.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
9
+ "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight": "model-00001-of-00002.safetensors",
10
+ "encoder.block.0.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
11
+ "encoder.block.0.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
12
+ "encoder.block.0.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
13
+ "encoder.block.0.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
14
+ "encoder.block.0.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
15
+ "encoder.block.0.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
16
+ "encoder.block.1.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
17
+ "encoder.block.1.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
18
+ "encoder.block.1.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
19
+ "encoder.block.1.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
20
+ "encoder.block.1.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
21
+ "encoder.block.1.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
22
+ "encoder.block.1.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
23
+ "encoder.block.1.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
24
+ "encoder.block.1.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
25
+ "encoder.block.10.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
26
+ "encoder.block.10.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
27
+ "encoder.block.10.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
28
+ "encoder.block.10.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
29
+ "encoder.block.10.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
30
+ "encoder.block.10.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
31
+ "encoder.block.10.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
32
+ "encoder.block.10.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
33
+ "encoder.block.10.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
34
+ "encoder.block.11.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
35
+ "encoder.block.11.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
36
+ "encoder.block.11.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
37
+ "encoder.block.11.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
38
+ "encoder.block.11.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
39
+ "encoder.block.11.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
40
+ "encoder.block.11.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
41
+ "encoder.block.11.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
42
+ "encoder.block.11.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
43
+ "encoder.block.12.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
44
+ "encoder.block.12.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
45
+ "encoder.block.12.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
46
+ "encoder.block.12.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
47
+ "encoder.block.12.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
48
+ "encoder.block.12.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
49
+ "encoder.block.12.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
50
+ "encoder.block.12.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
51
+ "encoder.block.12.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
52
+ "encoder.block.13.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
53
+ "encoder.block.13.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
54
+ "encoder.block.13.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
55
+ "encoder.block.13.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
56
+ "encoder.block.13.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
57
+ "encoder.block.13.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
58
+ "encoder.block.13.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
59
+ "encoder.block.13.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
60
+ "encoder.block.13.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
61
+ "encoder.block.14.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
62
+ "encoder.block.14.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
63
+ "encoder.block.14.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
64
+ "encoder.block.14.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
65
+ "encoder.block.14.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
66
+ "encoder.block.14.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
67
+ "encoder.block.14.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
68
+ "encoder.block.14.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
69
+ "encoder.block.14.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
70
+ "encoder.block.15.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
71
+ "encoder.block.15.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
72
+ "encoder.block.15.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
73
+ "encoder.block.15.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
74
+ "encoder.block.15.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
75
+ "encoder.block.15.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
76
+ "encoder.block.15.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
77
+ "encoder.block.15.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
78
+ "encoder.block.15.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
79
+ "encoder.block.16.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
80
+ "encoder.block.16.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
81
+ "encoder.block.16.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
82
+ "encoder.block.16.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
83
+ "encoder.block.16.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
84
+ "encoder.block.16.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
85
+ "encoder.block.16.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
86
+ "encoder.block.16.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
87
+ "encoder.block.16.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
88
+ "encoder.block.17.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
89
+ "encoder.block.17.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
90
+ "encoder.block.17.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
91
+ "encoder.block.17.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
92
+ "encoder.block.17.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
93
+ "encoder.block.17.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
94
+ "encoder.block.17.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
95
+ "encoder.block.17.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
96
+ "encoder.block.17.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
97
+ "encoder.block.18.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
98
+ "encoder.block.18.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
99
+ "encoder.block.18.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
100
+ "encoder.block.18.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
101
+ "encoder.block.18.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
102
+ "encoder.block.18.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
103
+ "encoder.block.18.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
104
+ "encoder.block.18.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
105
+ "encoder.block.18.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
106
+ "encoder.block.19.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
107
+ "encoder.block.19.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
108
+ "encoder.block.19.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
109
+ "encoder.block.19.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
110
+ "encoder.block.19.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
111
+ "encoder.block.19.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
112
+ "encoder.block.19.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
113
+ "encoder.block.19.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
114
+ "encoder.block.19.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
115
+ "encoder.block.2.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
116
+ "encoder.block.2.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
117
+ "encoder.block.2.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
118
+ "encoder.block.2.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
119
+ "encoder.block.2.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
120
+ "encoder.block.2.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
121
+ "encoder.block.2.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
122
+ "encoder.block.2.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
123
+ "encoder.block.2.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
124
+ "encoder.block.20.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
125
+ "encoder.block.20.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
126
+ "encoder.block.20.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
127
+ "encoder.block.20.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
128
+ "encoder.block.20.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
129
+ "encoder.block.20.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
130
+ "encoder.block.20.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
131
+ "encoder.block.20.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
132
+ "encoder.block.20.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
133
+ "encoder.block.21.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
134
+ "encoder.block.21.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
135
+ "encoder.block.21.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
136
+ "encoder.block.21.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
137
+ "encoder.block.21.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
138
+ "encoder.block.21.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
139
+ "encoder.block.21.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
140
+ "encoder.block.21.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
141
+ "encoder.block.21.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
142
+ "encoder.block.22.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
143
+ "encoder.block.22.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
144
+ "encoder.block.22.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
145
+ "encoder.block.22.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
146
+ "encoder.block.22.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
147
+ "encoder.block.22.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
148
+ "encoder.block.22.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
149
+ "encoder.block.22.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
150
+ "encoder.block.22.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
151
+ "encoder.block.23.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
152
+ "encoder.block.23.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
153
+ "encoder.block.23.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
154
+ "encoder.block.23.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
155
+ "encoder.block.23.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
156
+ "encoder.block.23.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
157
+ "encoder.block.23.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
158
+ "encoder.block.23.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
159
+ "encoder.block.23.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
160
+ "encoder.block.3.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
161
+ "encoder.block.3.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
162
+ "encoder.block.3.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
163
+ "encoder.block.3.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
164
+ "encoder.block.3.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
165
+ "encoder.block.3.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
166
+ "encoder.block.3.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
167
+ "encoder.block.3.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
168
+ "encoder.block.3.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
169
+ "encoder.block.4.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
170
+ "encoder.block.4.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
171
+ "encoder.block.4.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
172
+ "encoder.block.4.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
173
+ "encoder.block.4.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
174
+ "encoder.block.4.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
175
+ "encoder.block.4.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
176
+ "encoder.block.4.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
177
+ "encoder.block.4.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
178
+ "encoder.block.5.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
179
+ "encoder.block.5.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
180
+ "encoder.block.5.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
181
+ "encoder.block.5.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
182
+ "encoder.block.5.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
183
+ "encoder.block.5.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
184
+ "encoder.block.5.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
185
+ "encoder.block.5.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
186
+ "encoder.block.5.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
187
+ "encoder.block.6.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
188
+ "encoder.block.6.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
189
+ "encoder.block.6.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
190
+ "encoder.block.6.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
191
+ "encoder.block.6.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
192
+ "encoder.block.6.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
193
+ "encoder.block.6.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
194
+ "encoder.block.6.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
195
+ "encoder.block.6.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
196
+ "encoder.block.7.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
197
+ "encoder.block.7.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
198
+ "encoder.block.7.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
199
+ "encoder.block.7.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
200
+ "encoder.block.7.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
201
+ "encoder.block.7.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
202
+ "encoder.block.7.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
203
+ "encoder.block.7.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
204
+ "encoder.block.7.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
205
+ "encoder.block.8.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
206
+ "encoder.block.8.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
207
+ "encoder.block.8.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
208
+ "encoder.block.8.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
209
+ "encoder.block.8.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
210
+ "encoder.block.8.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
211
+ "encoder.block.8.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
212
+ "encoder.block.8.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
213
+ "encoder.block.8.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
214
+ "encoder.block.9.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
215
+ "encoder.block.9.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
216
+ "encoder.block.9.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
217
+ "encoder.block.9.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
218
+ "encoder.block.9.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
219
+ "encoder.block.9.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
220
+ "encoder.block.9.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
221
+ "encoder.block.9.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
222
+ "encoder.block.9.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
223
+ "encoder.final_layer_norm.weight": "model-00002-of-00002.safetensors",
224
+ "shared.weight": "model-00001-of-00002.safetensors"
225
+ }
226
+ }
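
The index above is the standard sharded-safetensors layout: `weight_map` points each parameter name at the shard file that stores it. A small sketch of resolving one tensor directly from the shards (paths assume the repository is checked out locally):

```python
import json
from safetensors import safe_open

# Load the shard index shipped with text_encoder_2.
with open("text_encoder_2/model.safetensors.index.json") as f:
    index = json.load(f)

name = "encoder.block.12.layer.0.SelfAttention.o.weight"
shard = index["weight_map"][name]   # -> "model-00002-of-00002.safetensors"

# Read just that tensor from the shard without loading the whole checkpoint.
with safe_open(f"text_encoder_2/{shard}", framework="pt") as st:
    tensor = st.get_tensor(name)

print(shard, tuple(tensor.shape))
```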
tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<|startoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<|endoftext|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<|endoftext|>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "49406": {
5
+ "content": "<|startoftext|>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "49407": {
13
+ "content": "<|endoftext|>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ }
20
+ },
21
+ "bos_token": "<|startoftext|>",
22
+ "clean_up_tokenization_spaces": true,
23
+ "do_lower_case": true,
24
+ "eos_token": "<|endoftext|>",
25
+ "errors": "replace",
26
+ "model_max_length": 77,
27
+ "pad_token": "<|endoftext|>",
28
+ "tokenizer_class": "CLIPTokenizer",
29
+ "unk_token": "<|endoftext|>"
30
+ }
tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_2/special_tokens_map.json ADDED
@@ -0,0 +1,125 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<extra_id_0>",
4
+ "<extra_id_1>",
5
+ "<extra_id_2>",
6
+ "<extra_id_3>",
7
+ "<extra_id_4>",
8
+ "<extra_id_5>",
9
+ "<extra_id_6>",
10
+ "<extra_id_7>",
11
+ "<extra_id_8>",
12
+ "<extra_id_9>",
13
+ "<extra_id_10>",
14
+ "<extra_id_11>",
15
+ "<extra_id_12>",
16
+ "<extra_id_13>",
17
+ "<extra_id_14>",
18
+ "<extra_id_15>",
19
+ "<extra_id_16>",
20
+ "<extra_id_17>",
21
+ "<extra_id_18>",
22
+ "<extra_id_19>",
23
+ "<extra_id_20>",
24
+ "<extra_id_21>",
25
+ "<extra_id_22>",
26
+ "<extra_id_23>",
27
+ "<extra_id_24>",
28
+ "<extra_id_25>",
29
+ "<extra_id_26>",
30
+ "<extra_id_27>",
31
+ "<extra_id_28>",
32
+ "<extra_id_29>",
33
+ "<extra_id_30>",
34
+ "<extra_id_31>",
35
+ "<extra_id_32>",
36
+ "<extra_id_33>",
37
+ "<extra_id_34>",
38
+ "<extra_id_35>",
39
+ "<extra_id_36>",
40
+ "<extra_id_37>",
41
+ "<extra_id_38>",
42
+ "<extra_id_39>",
43
+ "<extra_id_40>",
44
+ "<extra_id_41>",
45
+ "<extra_id_42>",
46
+ "<extra_id_43>",
47
+ "<extra_id_44>",
48
+ "<extra_id_45>",
49
+ "<extra_id_46>",
50
+ "<extra_id_47>",
51
+ "<extra_id_48>",
52
+ "<extra_id_49>",
53
+ "<extra_id_50>",
54
+ "<extra_id_51>",
55
+ "<extra_id_52>",
56
+ "<extra_id_53>",
57
+ "<extra_id_54>",
58
+ "<extra_id_55>",
59
+ "<extra_id_56>",
60
+ "<extra_id_57>",
61
+ "<extra_id_58>",
62
+ "<extra_id_59>",
63
+ "<extra_id_60>",
64
+ "<extra_id_61>",
65
+ "<extra_id_62>",
66
+ "<extra_id_63>",
67
+ "<extra_id_64>",
68
+ "<extra_id_65>",
69
+ "<extra_id_66>",
70
+ "<extra_id_67>",
71
+ "<extra_id_68>",
72
+ "<extra_id_69>",
73
+ "<extra_id_70>",
74
+ "<extra_id_71>",
75
+ "<extra_id_72>",
76
+ "<extra_id_73>",
77
+ "<extra_id_74>",
78
+ "<extra_id_75>",
79
+ "<extra_id_76>",
80
+ "<extra_id_77>",
81
+ "<extra_id_78>",
82
+ "<extra_id_79>",
83
+ "<extra_id_80>",
84
+ "<extra_id_81>",
85
+ "<extra_id_82>",
86
+ "<extra_id_83>",
87
+ "<extra_id_84>",
88
+ "<extra_id_85>",
89
+ "<extra_id_86>",
90
+ "<extra_id_87>",
91
+ "<extra_id_88>",
92
+ "<extra_id_89>",
93
+ "<extra_id_90>",
94
+ "<extra_id_91>",
95
+ "<extra_id_92>",
96
+ "<extra_id_93>",
97
+ "<extra_id_94>",
98
+ "<extra_id_95>",
99
+ "<extra_id_96>",
100
+ "<extra_id_97>",
101
+ "<extra_id_98>",
102
+ "<extra_id_99>"
103
+ ],
104
+ "eos_token": {
105
+ "content": "</s>",
106
+ "lstrip": false,
107
+ "normalized": false,
108
+ "rstrip": false,
109
+ "single_word": false
110
+ },
111
+ "pad_token": {
112
+ "content": "<pad>",
113
+ "lstrip": false,
114
+ "normalized": false,
115
+ "rstrip": false,
116
+ "single_word": false
117
+ },
118
+ "unk_token": {
119
+ "content": "<unk>",
120
+ "lstrip": false,
121
+ "normalized": false,
122
+ "rstrip": false,
123
+ "single_word": false
124
+ }
125
+ }
tokenizer_2/spiece.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86
3
+ size 791656
tokenizer_2/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_2/tokenizer_config.json ADDED
@@ -0,0 +1,940 @@
1
+ {
2
+ "add_prefix_space": true,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "<pad>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "1": {
13
+ "content": "</s>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "2": {
21
+ "content": "<unk>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "32000": {
29
+ "content": "<extra_id_99>",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "32001": {
37
+ "content": "<extra_id_98>",
38
+ "lstrip": false,
39
+ "normalized": false,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": true
43
+ },
44
+ "32002": {
45
+ "content": "<extra_id_97>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false,
50
+ "special": true
51
+ },
52
+ "32003": {
53
+ "content": "<extra_id_96>",
54
+ "lstrip": false,
55
+ "normalized": false,
56
+ "rstrip": false,
57
+ "single_word": false,
58
+ "special": true
59
+ },
60
+ "32004": {
61
+ "content": "<extra_id_95>",
62
+ "lstrip": false,
63
+ "normalized": false,
64
+ "rstrip": false,
65
+ "single_word": false,
66
+ "special": true
67
+ },
68
+ "32005": {
69
+ "content": "<extra_id_94>",
70
+ "lstrip": false,
71
+ "normalized": false,
72
+ "rstrip": false,
73
+ "single_word": false,
74
+ "special": true
75
+ },
76
+ "32006": {
77
+ "content": "<extra_id_93>",
78
+ "lstrip": false,
79
+ "normalized": false,
80
+ "rstrip": false,
81
+ "single_word": false,
82
+ "special": true
83
+ },
84
+ "32007": {
85
+ "content": "<extra_id_92>",
86
+ "lstrip": false,
87
+ "normalized": false,
88
+ "rstrip": false,
89
+ "single_word": false,
90
+ "special": true
91
+ },
92
+ "32008": {
93
+ "content": "<extra_id_91>",
94
+ "lstrip": false,
95
+ "normalized": false,
96
+ "rstrip": false,
97
+ "single_word": false,
98
+ "special": true
99
+ },
100
+ "32009": {
101
+ "content": "<extra_id_90>",
102
+ "lstrip": false,
103
+ "normalized": false,
104
+ "rstrip": false,
105
+ "single_word": false,
106
+ "special": true
107
+ },
108
+ "32010": {
109
+ "content": "<extra_id_89>",
110
+ "lstrip": false,
111
+ "normalized": false,
112
+ "rstrip": false,
113
+ "single_word": false,
114
+ "special": true
115
+ },
116
+ "32011": {
117
+ "content": "<extra_id_88>",
118
+ "lstrip": false,
119
+ "normalized": false,
120
+ "rstrip": false,
121
+ "single_word": false,
122
+ "special": true
123
+ },
124
+ "32012": {
125
+ "content": "<extra_id_87>",
126
+ "lstrip": false,
127
+ "normalized": false,
128
+ "rstrip": false,
129
+ "single_word": false,
130
+ "special": true
131
+ },
132
+ "32013": {
133
+ "content": "<extra_id_86>",
134
+ "lstrip": false,
135
+ "normalized": false,
136
+ "rstrip": false,
137
+ "single_word": false,
138
+ "special": true
139
+ },
140
+ "32014": {
141
+ "content": "<extra_id_85>",
142
+ "lstrip": false,
143
+ "normalized": false,
144
+ "rstrip": false,
145
+ "single_word": false,
146
+ "special": true
147
+ },
148
+ "32015": {
149
+ "content": "<extra_id_84>",
150
+ "lstrip": false,
151
+ "normalized": false,
152
+ "rstrip": false,
153
+ "single_word": false,
154
+ "special": true
155
+ },
156
+ "32016": {
157
+ "content": "<extra_id_83>",
158
+ "lstrip": false,
159
+ "normalized": false,
160
+ "rstrip": false,
161
+ "single_word": false,
162
+ "special": true
163
+ },
164
+ "32017": {
165
+ "content": "<extra_id_82>",
166
+ "lstrip": false,
167
+ "normalized": false,
168
+ "rstrip": false,
169
+ "single_word": false,
170
+ "special": true
171
+ },
172
+ "32018": {
173
+ "content": "<extra_id_81>",
174
+ "lstrip": false,
175
+ "normalized": false,
176
+ "rstrip": false,
177
+ "single_word": false,
178
+ "special": true
179
+ },
180
+ "32019": {
181
+ "content": "<extra_id_80>",
182
+ "lstrip": false,
183
+ "normalized": false,
184
+ "rstrip": false,
185
+ "single_word": false,
186
+ "special": true
187
+ },
188
+ "32020": {
189
+ "content": "<extra_id_79>",
190
+ "lstrip": false,
191
+ "normalized": false,
192
+ "rstrip": false,
193
+ "single_word": false,
194
+ "special": true
195
+ },
196
+ "32021": {
197
+ "content": "<extra_id_78>",
198
+ "lstrip": false,
199
+ "normalized": false,
200
+ "rstrip": false,
201
+ "single_word": false,
202
+ "special": true
203
+ },
204
+ "32022": {
205
+ "content": "<extra_id_77>",
206
+ "lstrip": false,
207
+ "normalized": false,
208
+ "rstrip": false,
209
+ "single_word": false,
210
+ "special": true
211
+ },
212
+ "32023": {
213
+ "content": "<extra_id_76>",
214
+ "lstrip": false,
215
+ "normalized": false,
216
+ "rstrip": false,
217
+ "single_word": false,
218
+ "special": true
219
+ },
220
+ "32024": {
221
+ "content": "<extra_id_75>",
222
+ "lstrip": false,
223
+ "normalized": false,
224
+ "rstrip": false,
225
+ "single_word": false,
226
+ "special": true
227
+ },
228
+ "32025": {
229
+ "content": "<extra_id_74>",
230
+ "lstrip": false,
231
+ "normalized": false,
232
+ "rstrip": false,
233
+ "single_word": false,
234
+ "special": true
235
+ },
236
+ "32026": {
237
+ "content": "<extra_id_73>",
238
+ "lstrip": false,
239
+ "normalized": false,
240
+ "rstrip": false,
241
+ "single_word": false,
242
+ "special": true
243
+ },
244
+ "32027": {
245
+ "content": "<extra_id_72>",
246
+ "lstrip": false,
247
+ "normalized": false,
248
+ "rstrip": false,
249
+ "single_word": false,
250
+ "special": true
251
+ },
252
+ "32028": {
253
+ "content": "<extra_id_71>",
254
+ "lstrip": false,
255
+ "normalized": false,
256
+ "rstrip": false,
257
+ "single_word": false,
258
+ "special": true
259
+ },
260
+ "32029": {
261
+ "content": "<extra_id_70>",
262
+ "lstrip": false,
263
+ "normalized": false,
264
+ "rstrip": false,
265
+ "single_word": false,
266
+ "special": true
267
+ },
268
+ "32030": {
269
+ "content": "<extra_id_69>",
270
+ "lstrip": false,
271
+ "normalized": false,
272
+ "rstrip": false,
273
+ "single_word": false,
274
+ "special": true
275
+ },
276
+ "32031": {
277
+ "content": "<extra_id_68>",
278
+ "lstrip": false,
279
+ "normalized": false,
280
+ "rstrip": false,
281
+ "single_word": false,
282
+ "special": true
283
+ },
284
+ "32032": {
285
+ "content": "<extra_id_67>",
286
+ "lstrip": false,
287
+ "normalized": false,
288
+ "rstrip": false,
289
+ "single_word": false,
290
+ "special": true
291
+ },
292
+ "32033": {
293
+ "content": "<extra_id_66>",
294
+ "lstrip": false,
295
+ "normalized": false,
296
+ "rstrip": false,
297
+ "single_word": false,
298
+ "special": true
299
+ },
300
+ "32034": {
301
+ "content": "<extra_id_65>",
302
+ "lstrip": false,
303
+ "normalized": false,
304
+ "rstrip": false,
305
+ "single_word": false,
306
+ "special": true
307
+ },
308
+ "32035": {
309
+ "content": "<extra_id_64>",
310
+ "lstrip": false,
311
+ "normalized": false,
312
+ "rstrip": false,
313
+ "single_word": false,
314
+ "special": true
315
+ },
316
+ "32036": {
317
+ "content": "<extra_id_63>",
318
+ "lstrip": false,
319
+ "normalized": false,
320
+ "rstrip": false,
321
+ "single_word": false,
322
+ "special": true
323
+ },
324
+ "32037": {
325
+ "content": "<extra_id_62>",
326
+ "lstrip": false,
327
+ "normalized": false,
328
+ "rstrip": false,
329
+ "single_word": false,
330
+ "special": true
331
+ },
332
+ "32038": {
333
+ "content": "<extra_id_61>",
334
+ "lstrip": false,
335
+ "normalized": false,
336
+ "rstrip": false,
337
+ "single_word": false,
338
+ "special": true
339
+ },
340
+ "32039": {
341
+ "content": "<extra_id_60>",
342
+ "lstrip": false,
343
+ "normalized": false,
344
+ "rstrip": false,
345
+ "single_word": false,
346
+ "special": true
347
+ },
348
+ "32040": {
349
+ "content": "<extra_id_59>",
350
+ "lstrip": false,
351
+ "normalized": false,
352
+ "rstrip": false,
353
+ "single_word": false,
354
+ "special": true
355
+ },
356
+ "32041": {
357
+ "content": "<extra_id_58>",
358
+ "lstrip": false,
359
+ "normalized": false,
360
+ "rstrip": false,
361
+ "single_word": false,
362
+ "special": true
363
+ },
364
+ "32042": {
365
+ "content": "<extra_id_57>",
366
+ "lstrip": false,
367
+ "normalized": false,
368
+ "rstrip": false,
369
+ "single_word": false,
370
+ "special": true
371
+ },
372
+ "32043": {
373
+ "content": "<extra_id_56>",
374
+ "lstrip": false,
375
+ "normalized": false,
376
+ "rstrip": false,
377
+ "single_word": false,
378
+ "special": true
379
+ },
380
+ "32044": {
381
+ "content": "<extra_id_55>",
382
+ "lstrip": false,
383
+ "normalized": false,
384
+ "rstrip": false,
385
+ "single_word": false,
386
+ "special": true
387
+ },
388
+ "32045": {
389
+ "content": "<extra_id_54>",
390
+ "lstrip": false,
391
+ "normalized": false,
392
+ "rstrip": false,
393
+ "single_word": false,
394
+ "special": true
395
+ },
396
+ "32046": {
397
+ "content": "<extra_id_53>",
398
+ "lstrip": false,
399
+ "normalized": false,
400
+ "rstrip": false,
401
+ "single_word": false,
402
+ "special": true
403
+ },
404
+ "32047": {
405
+ "content": "<extra_id_52>",
406
+ "lstrip": false,
407
+ "normalized": false,
408
+ "rstrip": false,
409
+ "single_word": false,
410
+ "special": true
411
+ },
412
+ "32048": {
413
+ "content": "<extra_id_51>",
414
+ "lstrip": false,
415
+ "normalized": false,
416
+ "rstrip": false,
417
+ "single_word": false,
418
+ "special": true
419
+ },
420
+ "32049": {
421
+ "content": "<extra_id_50>",
422
+ "lstrip": false,
423
+ "normalized": false,
424
+ "rstrip": false,
425
+ "single_word": false,
426
+ "special": true
427
+ },
428
+ "32050": {
429
+ "content": "<extra_id_49>",
430
+ "lstrip": false,
431
+ "normalized": false,
432
+ "rstrip": false,
433
+ "single_word": false,
434
+ "special": true
435
+ },
436
+ "32051": {
437
+ "content": "<extra_id_48>",
438
+ "lstrip": false,
439
+ "normalized": false,
440
+ "rstrip": false,
441
+ "single_word": false,
442
+ "special": true
443
+ },
444
+ "32052": {
445
+ "content": "<extra_id_47>",
446
+ "lstrip": false,
447
+ "normalized": false,
448
+ "rstrip": false,
449
+ "single_word": false,
450
+ "special": true
451
+ },
452
+ "32053": {
453
+ "content": "<extra_id_46>",
454
+ "lstrip": false,
455
+ "normalized": false,
456
+ "rstrip": false,
457
+ "single_word": false,
458
+ "special": true
459
+ },
460
+ "32054": {
461
+ "content": "<extra_id_45>",
462
+ "lstrip": false,
463
+ "normalized": false,
464
+ "rstrip": false,
465
+ "single_word": false,
466
+ "special": true
467
+ },
468
+ "32055": {
469
+ "content": "<extra_id_44>",
470
+ "lstrip": false,
471
+ "normalized": false,
472
+ "rstrip": false,
473
+ "single_word": false,
474
+ "special": true
475
+ },
476
+ "32056": {
477
+ "content": "<extra_id_43>",
478
+ "lstrip": false,
479
+ "normalized": false,
480
+ "rstrip": false,
481
+ "single_word": false,
482
+ "special": true
483
+ },
484
+ "32057": {
485
+ "content": "<extra_id_42>",
486
+ "lstrip": false,
487
+ "normalized": false,
488
+ "rstrip": false,
489
+ "single_word": false,
490
+ "special": true
491
+ },
492
+ "32058": {
493
+ "content": "<extra_id_41>",
494
+ "lstrip": false,
495
+ "normalized": false,
496
+ "rstrip": false,
497
+ "single_word": false,
498
+ "special": true
499
+ },
500
+ "32059": {
501
+ "content": "<extra_id_40>",
502
+ "lstrip": false,
503
+ "normalized": false,
504
+ "rstrip": false,
505
+ "single_word": false,
506
+ "special": true
507
+ },
508
+ "32060": {
509
+ "content": "<extra_id_39>",
510
+ "lstrip": false,
511
+ "normalized": false,
512
+ "rstrip": false,
513
+ "single_word": false,
514
+ "special": true
515
+ },
516
+ "32061": {
517
+ "content": "<extra_id_38>",
518
+ "lstrip": false,
519
+ "normalized": false,
520
+ "rstrip": false,
521
+ "single_word": false,
522
+ "special": true
523
+ },
524
+ "32062": {
525
+ "content": "<extra_id_37>",
526
+ "lstrip": false,
527
+ "normalized": false,
528
+ "rstrip": false,
529
+ "single_word": false,
530
+ "special": true
531
+ },
532
+ "32063": {
533
+ "content": "<extra_id_36>",
534
+ "lstrip": false,
535
+ "normalized": false,
536
+ "rstrip": false,
537
+ "single_word": false,
538
+ "special": true
539
+ },
540
+ "32064": {
541
+ "content": "<extra_id_35>",
542
+ "lstrip": false,
543
+ "normalized": false,
544
+ "rstrip": false,
545
+ "single_word": false,
546
+ "special": true
547
+ },
548
+ "32065": {
549
+ "content": "<extra_id_34>",
550
+ "lstrip": false,
551
+ "normalized": false,
552
+ "rstrip": false,
553
+ "single_word": false,
554
+ "special": true
555
+ },
556
+ "32066": {
557
+ "content": "<extra_id_33>",
558
+ "lstrip": false,
559
+ "normalized": false,
560
+ "rstrip": false,
561
+ "single_word": false,
562
+ "special": true
563
+ },
564
+ "32067": {
565
+ "content": "<extra_id_32>",
566
+ "lstrip": false,
567
+ "normalized": false,
568
+ "rstrip": false,
569
+ "single_word": false,
570
+ "special": true
571
+ },
572
+ "32068": {
573
+ "content": "<extra_id_31>",
574
+ "lstrip": false,
575
+ "normalized": false,
576
+ "rstrip": false,
577
+ "single_word": false,
578
+ "special": true
579
+ },
580
+ "32069": {
581
+ "content": "<extra_id_30>",
582
+ "lstrip": false,
583
+ "normalized": false,
584
+ "rstrip": false,
585
+ "single_word": false,
586
+ "special": true
587
+ },
588
+ "32070": {
589
+ "content": "<extra_id_29>",
590
+ "lstrip": false,
591
+ "normalized": false,
592
+ "rstrip": false,
593
+ "single_word": false,
594
+ "special": true
595
+ },
596
+ "32071": {
597
+ "content": "<extra_id_28>",
598
+ "lstrip": false,
599
+ "normalized": false,
600
+ "rstrip": false,
601
+ "single_word": false,
602
+ "special": true
603
+ },
604
+ "32072": {
605
+ "content": "<extra_id_27>",
606
+ "lstrip": false,
607
+ "normalized": false,
608
+ "rstrip": false,
609
+ "single_word": false,
610
+ "special": true
611
+ },
612
+ "32073": {
613
+ "content": "<extra_id_26>",
614
+ "lstrip": false,
615
+ "normalized": false,
616
+ "rstrip": false,
617
+ "single_word": false,
618
+ "special": true
619
+ },
620
+ "32074": {
621
+ "content": "<extra_id_25>",
622
+ "lstrip": false,
623
+ "normalized": false,
624
+ "rstrip": false,
625
+ "single_word": false,
626
+ "special": true
627
+ },
628
+ "32075": {
629
+ "content": "<extra_id_24>",
630
+ "lstrip": false,
631
+ "normalized": false,
632
+ "rstrip": false,
633
+ "single_word": false,
634
+ "special": true
635
+ },
636
+ "32076": {
637
+ "content": "<extra_id_23>",
638
+ "lstrip": false,
639
+ "normalized": false,
640
+ "rstrip": false,
641
+ "single_word": false,
642
+ "special": true
643
+ },
644
+ "32077": {
645
+ "content": "<extra_id_22>",
646
+ "lstrip": false,
647
+ "normalized": false,
648
+ "rstrip": false,
649
+ "single_word": false,
650
+ "special": true
651
+ },
652
+ "32078": {
653
+ "content": "<extra_id_21>",
654
+ "lstrip": false,
655
+ "normalized": false,
656
+ "rstrip": false,
657
+ "single_word": false,
658
+ "special": true
659
+ },
660
+ "32079": {
661
+ "content": "<extra_id_20>",
662
+ "lstrip": false,
663
+ "normalized": false,
664
+ "rstrip": false,
665
+ "single_word": false,
666
+ "special": true
667
+ },
668
+ "32080": {
669
+ "content": "<extra_id_19>",
670
+ "lstrip": false,
671
+ "normalized": false,
672
+ "rstrip": false,
673
+ "single_word": false,
674
+ "special": true
675
+ },
676
+ "32081": {
677
+ "content": "<extra_id_18>",
678
+ "lstrip": false,
679
+ "normalized": false,
680
+ "rstrip": false,
681
+ "single_word": false,
682
+ "special": true
683
+ },
684
+ "32082": {
685
+ "content": "<extra_id_17>",
686
+ "lstrip": false,
687
+ "normalized": false,
688
+ "rstrip": false,
689
+ "single_word": false,
690
+ "special": true
691
+ },
692
+ "32083": {
693
+ "content": "<extra_id_16>",
694
+ "lstrip": false,
695
+ "normalized": false,
696
+ "rstrip": false,
697
+ "single_word": false,
698
+ "special": true
699
+ },
700
+ "32084": {
701
+ "content": "<extra_id_15>",
702
+ "lstrip": false,
703
+ "normalized": false,
704
+ "rstrip": false,
705
+ "single_word": false,
706
+ "special": true
707
+ },
708
+ "32085": {
709
+ "content": "<extra_id_14>",
710
+ "lstrip": false,
711
+ "normalized": false,
712
+ "rstrip": false,
713
+ "single_word": false,
714
+ "special": true
715
+ },
716
+ "32086": {
717
+ "content": "<extra_id_13>",
718
+ "lstrip": false,
719
+ "normalized": false,
720
+ "rstrip": false,
721
+ "single_word": false,
722
+ "special": true
723
+ },
724
+ "32087": {
725
+ "content": "<extra_id_12>",
726
+ "lstrip": false,
727
+ "normalized": false,
728
+ "rstrip": false,
729
+ "single_word": false,
730
+ "special": true
731
+ },
732
+ "32088": {
733
+ "content": "<extra_id_11>",
734
+ "lstrip": false,
735
+ "normalized": false,
736
+ "rstrip": false,
737
+ "single_word": false,
738
+ "special": true
739
+ },
740
+ "32089": {
741
+ "content": "<extra_id_10>",
742
+ "lstrip": false,
743
+ "normalized": false,
744
+ "rstrip": false,
745
+ "single_word": false,
746
+ "special": true
747
+ },
748
+ "32090": {
749
+ "content": "<extra_id_9>",
750
+ "lstrip": false,
751
+ "normalized": false,
752
+ "rstrip": false,
753
+ "single_word": false,
754
+ "special": true
755
+ },
756
+ "32091": {
757
+ "content": "<extra_id_8>",
758
+ "lstrip": false,
759
+ "normalized": false,
760
+ "rstrip": false,
761
+ "single_word": false,
762
+ "special": true
763
+ },
764
+ "32092": {
765
+ "content": "<extra_id_7>",
766
+ "lstrip": false,
767
+ "normalized": false,
768
+ "rstrip": false,
769
+ "single_word": false,
770
+ "special": true
771
+ },
772
+ "32093": {
773
+ "content": "<extra_id_6>",
774
+ "lstrip": false,
775
+ "normalized": false,
776
+ "rstrip": false,
777
+ "single_word": false,
778
+ "special": true
779
+ },
780
+ "32094": {
781
+ "content": "<extra_id_5>",
782
+ "lstrip": false,
783
+ "normalized": false,
784
+ "rstrip": false,
785
+ "single_word": false,
786
+ "special": true
787
+ },
788
+ "32095": {
789
+ "content": "<extra_id_4>",
790
+ "lstrip": false,
791
+ "normalized": false,
792
+ "rstrip": false,
793
+ "single_word": false,
794
+ "special": true
795
+ },
796
+ "32096": {
797
+ "content": "<extra_id_3>",
798
+ "lstrip": false,
799
+ "normalized": false,
800
+ "rstrip": false,
801
+ "single_word": false,
802
+ "special": true
803
+ },
804
+ "32097": {
805
+ "content": "<extra_id_2>",
806
+ "lstrip": false,
807
+ "normalized": false,
808
+ "rstrip": false,
809
+ "single_word": false,
810
+ "special": true
811
+ },
812
+ "32098": {
813
+ "content": "<extra_id_1>",
814
+ "lstrip": false,
815
+ "normalized": false,
816
+ "rstrip": false,
817
+ "single_word": false,
818
+ "special": true
819
+ },
820
+ "32099": {
821
+ "content": "<extra_id_0>",
822
+ "lstrip": false,
823
+ "normalized": false,
824
+ "rstrip": false,
825
+ "single_word": false,
826
+ "special": true
827
+ }
828
+ },
829
+ "additional_special_tokens": [
830
+ "<extra_id_0>",
831
+ "<extra_id_1>",
832
+ "<extra_id_2>",
833
+ "<extra_id_3>",
834
+ "<extra_id_4>",
835
+ "<extra_id_5>",
836
+ "<extra_id_6>",
837
+ "<extra_id_7>",
838
+ "<extra_id_8>",
839
+ "<extra_id_9>",
840
+ "<extra_id_10>",
841
+ "<extra_id_11>",
842
+ "<extra_id_12>",
843
+ "<extra_id_13>",
844
+ "<extra_id_14>",
845
+ "<extra_id_15>",
846
+ "<extra_id_16>",
847
+ "<extra_id_17>",
848
+ "<extra_id_18>",
849
+ "<extra_id_19>",
850
+ "<extra_id_20>",
851
+ "<extra_id_21>",
852
+ "<extra_id_22>",
853
+ "<extra_id_23>",
854
+ "<extra_id_24>",
855
+ "<extra_id_25>",
856
+ "<extra_id_26>",
857
+ "<extra_id_27>",
858
+ "<extra_id_28>",
859
+ "<extra_id_29>",
860
+ "<extra_id_30>",
861
+ "<extra_id_31>",
862
+ "<extra_id_32>",
863
+ "<extra_id_33>",
864
+ "<extra_id_34>",
865
+ "<extra_id_35>",
866
+ "<extra_id_36>",
867
+ "<extra_id_37>",
868
+ "<extra_id_38>",
869
+ "<extra_id_39>",
870
+ "<extra_id_40>",
871
+ "<extra_id_41>",
872
+ "<extra_id_42>",
873
+ "<extra_id_43>",
874
+ "<extra_id_44>",
875
+ "<extra_id_45>",
876
+ "<extra_id_46>",
877
+ "<extra_id_47>",
878
+ "<extra_id_48>",
879
+ "<extra_id_49>",
880
+ "<extra_id_50>",
881
+ "<extra_id_51>",
882
+ "<extra_id_52>",
883
+ "<extra_id_53>",
884
+ "<extra_id_54>",
885
+ "<extra_id_55>",
886
+ "<extra_id_56>",
887
+ "<extra_id_57>",
888
+ "<extra_id_58>",
889
+ "<extra_id_59>",
890
+ "<extra_id_60>",
891
+ "<extra_id_61>",
892
+ "<extra_id_62>",
893
+ "<extra_id_63>",
894
+ "<extra_id_64>",
895
+ "<extra_id_65>",
896
+ "<extra_id_66>",
897
+ "<extra_id_67>",
898
+ "<extra_id_68>",
899
+ "<extra_id_69>",
900
+ "<extra_id_70>",
901
+ "<extra_id_71>",
902
+ "<extra_id_72>",
903
+ "<extra_id_73>",
904
+ "<extra_id_74>",
905
+ "<extra_id_75>",
906
+ "<extra_id_76>",
907
+ "<extra_id_77>",
908
+ "<extra_id_78>",
909
+ "<extra_id_79>",
910
+ "<extra_id_80>",
911
+ "<extra_id_81>",
912
+ "<extra_id_82>",
913
+ "<extra_id_83>",
914
+ "<extra_id_84>",
915
+ "<extra_id_85>",
916
+ "<extra_id_86>",
917
+ "<extra_id_87>",
918
+ "<extra_id_88>",
919
+ "<extra_id_89>",
920
+ "<extra_id_90>",
921
+ "<extra_id_91>",
922
+ "<extra_id_92>",
923
+ "<extra_id_93>",
924
+ "<extra_id_94>",
925
+ "<extra_id_95>",
926
+ "<extra_id_96>",
927
+ "<extra_id_97>",
928
+ "<extra_id_98>",
929
+ "<extra_id_99>"
930
+ ],
931
+ "clean_up_tokenization_spaces": true,
932
+ "eos_token": "</s>",
933
+ "extra_ids": 100,
934
+ "legacy": true,
935
+ "model_max_length": 512,
936
+ "pad_token": "<pad>",
937
+ "sp_model_kwargs": {},
938
+ "tokenizer_class": "T5Tokenizer",
939
+ "unk_token": "<unk>"
940
+ }
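The tokenizer settings above describe a stock T5 tokenizer: 100 <extra_id_*> sentinel tokens, a 512-token maximum length, and T5Tokenizer as the class. A minimal loading sketch (not part of the upload); the repo path is a placeholder and the "tokenizer_2" subfolder name is an assumption based on the repository layout rather than something stated in this diff:
from transformers import T5Tokenizer
# Sketch only: "path/to/this/repo" is a placeholder for a local clone;
# "tokenizer_2" is assumed from the repo layout, not stated in this diff.
tokenizer = T5Tokenizer.from_pretrained("path/to/this/repo", subfolder="tokenizer_2")
tokens = tokenizer(
    "a marble statue in a garden",
    max_length=512,      # matches "model_max_length": 512 above
    truncation=True,
    return_tensors="pt",
)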
transformer/config.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "_class_name": "LibreFluxTransformer2DModel",
3
+ "_diffusers_version": "0.35.2",
4
+ "_name_or_path": "/path/to/transformer",
5
+ "attention_head_dim": 128,
6
+ "axes_dims_rope": [
7
+ 16,
8
+ 56,
9
+ 56
10
+ ],
11
+ "guidance_embeds": false,
12
+ "in_channels": 64,
13
+ "joint_attention_dim": 4096,
14
+ "num_attention_heads": 24,
15
+ "num_layers": 19,
16
+ "num_single_layers": 38,
17
+ "patch_size": 1,
18
+ "pooled_projection_dim": 768
19
+ }
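This config drives the LibreFluxTransformer2DModel class defined in transformer/trans.py below: 19 double-stream blocks, 38 single-stream blocks, 24 attention heads of dimension 128, and no guidance embeddings. A minimal loading sketch (not part of the upload), assuming trans.py has been copied locally so the class can be imported; from_pretrained is inherited from diffusers' ModelMixin and reads this config plus the three sharded safetensors files in the same folder:
import torch
from trans import LibreFluxTransformer2DModel  # class defined in transformer/trans.py below
# Sketch only: "path/to/this/repo" is a placeholder for a local clone of the repo.
transformer = LibreFluxTransformer2DModel.from_pretrained(
    "path/to/this/repo",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)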
transformer/diffusion_pytorch_model-00001-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:18c2abe01a326d95bc836cfd5f68167118c0ecb2c8ccbcf5d6de4dbad47ca53c
3
+ size 9962580296
transformer/diffusion_pytorch_model-00002-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:828f131e306b17535c8c1d0a3c4aaa06f2a60a80612500da229829242f3ed422
3
+ size 9949328904
transformer/diffusion_pytorch_model-00003-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a17988ef4255372dd902ff5742e647f8a60dcc83756d740d1fbcf81d13d38162
3
+ size 3870584832
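Across the three shards the transformer weights total 9,962,580,296 + 9,949,328,904 + 3,870,584,832 bytes, roughly 23.8 GB, which is consistent with a Flux-scale transformer of about 12 billion parameters stored at 2 bytes per parameter.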
transformer/diffusion_pytorch_model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
transformer/trans.py ADDED
@@ -0,0 +1,794 @@
1
+ # Copyright 2024 Stability AI, The HuggingFace Team, The InstantX Team, and Terminus Research Group. All rights reserved.
2
+ #
3
+ # Originally licensed under the Apache License, Version 2.0 (the "License");
4
+ # Updated to "Affero GENERAL PUBLIC LICENSE Version 3, 19 November 2007" via extensive updates to attn_mask usage.
5
+ # This was taken from SimpleTuner and modified as needed
6
+
7
+ from typing import Any, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import numpy as np
13
+
14
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
15
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
16
+ from diffusers.models.attention import FeedForward
17
+ from diffusers.models.attention_processor import (
18
+ Attention,
19
+ AttentionProcessor,
20
+ )
21
+ from diffusers.models.modeling_utils import ModelMixin
22
+ from diffusers.models.normalization import (
23
+ AdaLayerNormContinuous,
24
+ AdaLayerNormZero,
25
+ AdaLayerNormZeroSingle,
26
+ )
27
+ from diffusers.utils import (
28
+ USE_PEFT_BACKEND,
29
+ is_torch_version,
30
+ logging,
31
+ scale_lora_layers,
32
+ unscale_lora_layers,
33
+ )
34
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
35
+ from diffusers.models.embeddings import (
36
+ CombinedTimestepGuidanceTextProjEmbeddings,
37
+ CombinedTimestepTextProjEmbeddings,
38
+ FluxPosEmbed,
39
+ )
40
+
41
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
42
+ from diffusers import FluxTransformer2DModel as OriginalFluxTransformer2DModel
43
+
44
+
45
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
46
+
47
+
48
+
49
+ class FluxAttnProcessor2_0:
50
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
51
+
52
+ def __init__(self):
53
+ if not hasattr(F, "scaled_dot_product_attention"):
54
+ raise ImportError(
55
+ "FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
56
+ )
57
+
58
+ def __call__(
59
+ self,
60
+ attn: Attention,
61
+ hidden_states: torch.FloatTensor,
62
+ encoder_hidden_states: torch.FloatTensor = None,
63
+ attention_mask: Optional[torch.FloatTensor] = None,
64
+ image_rotary_emb: Optional[torch.Tensor] = None,
65
+ ip_encoder_hidden_states:Optional[torch.Tensor] = None,
66
+ layer_scale:Optional[torch.Tensor] = None,
67
+ ) -> torch.FloatTensor:
68
+ batch_size, _, _ = (
69
+ hidden_states.shape
70
+ if encoder_hidden_states is None
71
+ else encoder_hidden_states.shape
72
+ )
73
+
74
+ # `sample` projections.
75
+ query = attn.to_q(hidden_states)
76
+ key = attn.to_k(hidden_states)
77
+ value = attn.to_v(hidden_states)
78
+
79
+ inner_dim = key.shape[-1]
80
+ head_dim = inner_dim // attn.heads
81
+
82
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
83
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
84
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
85
+
86
+ if attn.norm_q is not None:
87
+ query = attn.norm_q(query)
88
+ if attn.norm_k is not None:
89
+ key = attn.norm_k(key)
90
+
91
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
92
+ if encoder_hidden_states is not None:
93
+ # `context` projections.
94
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
95
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
96
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
97
+
98
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
99
+ batch_size, -1, attn.heads, head_dim
100
+ ).transpose(1, 2)
101
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
102
+ batch_size, -1, attn.heads, head_dim
103
+ ).transpose(1, 2)
104
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
105
+ batch_size, -1, attn.heads, head_dim
106
+ ).transpose(1, 2)
107
+
108
+ if attn.norm_added_q is not None:
109
+ encoder_hidden_states_query_proj = attn.norm_added_q(
110
+ encoder_hidden_states_query_proj
111
+ )
112
+ if attn.norm_added_k is not None:
113
+ encoder_hidden_states_key_proj = attn.norm_added_k(
114
+ encoder_hidden_states_key_proj
115
+ )
116
+
117
+ # attention
118
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
119
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
120
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
121
+
122
+ if image_rotary_emb is not None:
123
+ from diffusers.models.embeddings import apply_rotary_emb
124
+
125
+ query = apply_rotary_emb(query, image_rotary_emb)
126
+ key = apply_rotary_emb(key, image_rotary_emb)
127
+
128
+ if attention_mask is not None:
129
+ #print ('Attention Used')
130
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
131
+ attention_mask = (attention_mask > 0).bool()
132
+ # Edit 17 - match attention mask dtype to query dtype
133
+ attention_mask = attention_mask.to(
134
+ device=hidden_states.device, dtype=query.dtype
135
+ )
136
+
137
+ hidden_states = F.scaled_dot_product_attention(
138
+ query,
139
+ key,
140
+ value,
141
+ dropout_p=0.0,
142
+ is_causal=False,
143
+ attn_mask=attention_mask,
144
+ )
145
+ hidden_states = hidden_states.transpose(1, 2).reshape(
146
+ batch_size, -1, attn.heads * head_dim
147
+ )
148
+ hidden_states = hidden_states.to(query.dtype)
149
+
150
+ if encoder_hidden_states is not None:
151
+ encoder_hidden_states, hidden_states = (
152
+ hidden_states[:, : encoder_hidden_states.shape[1]],
153
+ hidden_states[:, encoder_hidden_states.shape[1] :],
154
+ )
155
+
156
+ # linear proj
157
+ hidden_states = attn.to_out[0](hidden_states)
158
+ # dropout
159
+ hidden_states = attn.to_out[1](hidden_states)
160
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
161
+
162
+ return hidden_states, encoder_hidden_states
163
+ return hidden_states
164
+
165
+
166
+ def expand_flux_attention_mask(
167
+ hidden_states: torch.Tensor,
168
+ attn_mask: torch.Tensor,
169
+ ) -> torch.Tensor:
170
+ """
171
+ Expand a mask so that the image is included.
172
+ """
173
+ bsz = attn_mask.shape[0]
174
+ assert bsz == hidden_states.shape[0]
175
+ residual_seq_len = hidden_states.shape[1]
176
+ mask_seq_len = attn_mask.shape[1]
177
+
178
+ expanded_mask = torch.ones(bsz, residual_seq_len)
179
+ expanded_mask[:, :mask_seq_len] = attn_mask
180
+
181
+ return expanded_mask
182
+
183
+
184
+ @maybe_allow_in_graph
185
+ class FluxSingleTransformerBlock(nn.Module):
186
+ r"""
187
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
188
+
189
+ Reference: https://arxiv.org/abs/2403.03206
190
+
191
+ Parameters:
192
+ dim (`int`): The number of channels in the input and output.
193
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
194
+ attention_head_dim (`int`): The number of channels in each head.
195
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
196
+ processing of `context` conditions.
197
+ """
198
+
199
+ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
200
+ super().__init__()
201
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
202
+
203
+ self.norm = AdaLayerNormZeroSingle(dim)
204
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
205
+ self.act_mlp = nn.GELU(approximate="tanh")
206
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
207
+
208
+ processor = FluxAttnProcessor2_0()
209
+ self.attn = Attention(
210
+ query_dim=dim,
211
+ cross_attention_dim=None,
212
+ dim_head=attention_head_dim,
213
+ heads=num_attention_heads,
214
+ out_dim=dim,
215
+ bias=True,
216
+ processor=processor,
217
+ qk_norm="rms_norm",
218
+ eps=1e-6,
219
+ pre_only=True,
220
+ )
221
+
222
+ def forward(
223
+ self,
224
+ hidden_states: torch.FloatTensor,
225
+ temb: torch.FloatTensor,
226
+ image_rotary_emb=None,
227
+ attention_mask: Optional[torch.Tensor] = None,
228
+ joint_attention_kwargs: dict = {},
229
+ ):
230
+ residual = hidden_states
231
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
232
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
233
+
234
+ if attention_mask is not None:
235
+ attention_mask = expand_flux_attention_mask(
236
+ hidden_states,
237
+ attention_mask,
238
+ )
239
+
240
+ # Adding ability to pass hidden state info to IP Adapter
241
+ ip_encoder_hidden_states = None
242
+ ip_layer_scale = 1.0
243
+ if 'ip_hidden_states' in joint_attention_kwargs:
244
+ ip_encoder_hidden_states = joint_attention_kwargs['ip_hidden_states']
245
+ if 'ip_layer_scale' in joint_attention_kwargs:
246
+ ip_layer_scale = joint_attention_kwargs['ip_layer_scale']
247
+
248
+ # Attention.
249
+ attn_output = self.attn(
250
+ hidden_states=norm_hidden_states,
251
+ image_rotary_emb=image_rotary_emb,
252
+ attention_mask=attention_mask,
253
+ ip_encoder_hidden_states=ip_encoder_hidden_states,
254
+ layer_scale=ip_layer_scale,
255
+ )
256
+ #attn_output = self.attn(
257
+ # hidden_states=norm_hidden_states,
258
+ # image_rotary_emb=image_rotary_emb,
259
+ # attention_mask=attention_mask,
260
+ #)
261
+
262
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
263
+ gate = gate.unsqueeze(1)
264
+ hidden_states = gate * self.proj_out(hidden_states)
265
+ hidden_states = residual + hidden_states
266
+
267
+ if hidden_states.dtype == torch.float16:
268
+ hidden_states = hidden_states.clip(-65504, 65504)
269
+
270
+ return hidden_states
271
+
272
+
273
+ @maybe_allow_in_graph
274
+ class FluxTransformerBlock(nn.Module):
275
+ r"""
276
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
277
+
278
+ Reference: https://arxiv.org/abs/2403.03206
279
+
280
+ Parameters:
281
+ dim (`int`): The number of channels in the input and output.
282
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
283
+ attention_head_dim (`int`): The number of channels in each head.
284
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
285
+ processing of `context` conditions.
286
+ """
287
+
288
+ def __init__(
289
+ self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6
290
+ ):
291
+ super().__init__()
292
+
293
+ self.norm1 = AdaLayerNormZero(dim)
294
+
295
+ self.norm1_context = AdaLayerNormZero(dim)
296
+
297
+ if hasattr(F, "scaled_dot_product_attention"):
298
+ processor = FluxAttnProcessor2_0()
299
+ else:
300
+ raise ValueError(
301
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
302
+ )
303
+ self.attn = Attention(
304
+ query_dim=dim,
305
+ cross_attention_dim=None,
306
+ added_kv_proj_dim=dim,
307
+ dim_head=attention_head_dim,
308
+ heads=num_attention_heads,
309
+ out_dim=dim,
310
+ context_pre_only=False,
311
+ bias=True,
312
+ processor=processor,
313
+ qk_norm=qk_norm,
314
+ eps=eps,
315
+ )
316
+
317
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
318
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
319
+
320
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
321
+ self.ff_context = FeedForward(
322
+ dim=dim, dim_out=dim, activation_fn="gelu-approximate"
323
+ )
324
+
325
+ # let chunk size default to None
326
+ self._chunk_size = None
327
+ self._chunk_dim = 0
328
+
329
+ def forward(
330
+ self,
331
+ hidden_states: torch.FloatTensor,
332
+ encoder_hidden_states: torch.FloatTensor,
333
+ temb: torch.FloatTensor,
334
+ image_rotary_emb=None,
335
+ attention_mask: Optional[torch.Tensor] = None,
336
+ joint_attention_kwargs: dict = {},
337
+ ):
338
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
339
+ hidden_states, emb=temb
340
+ )
341
+
342
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
343
+ self.norm1_context(encoder_hidden_states, emb=temb)
344
+ )
345
+
346
+ if attention_mask is not None:
347
+ attention_mask = expand_flux_attention_mask(
348
+ torch.cat([encoder_hidden_states, hidden_states], dim=1),
349
+ attention_mask,
350
+ )
351
+
352
+ # Adding ability to pass hidden state info to IP Adapter
353
+ ip_encoder_hidden_states = None
354
+ ip_layer_scale = 1.0
355
+ if 'ip_hidden_states' in joint_attention_kwargs:
356
+ ip_encoder_hidden_states = joint_attention_kwargs['ip_hidden_states']
357
+ if 'ip_layer_scale' in joint_attention_kwargs:
358
+ ip_layer_scale = joint_attention_kwargs['ip_layer_scale']
359
+
360
+ # Attention.
361
+ attention_outputs = self.attn(
362
+ hidden_states=norm_hidden_states,
363
+ encoder_hidden_states=norm_encoder_hidden_states,
364
+ image_rotary_emb=image_rotary_emb,
365
+ attention_mask=attention_mask,
366
+ ip_encoder_hidden_states=ip_encoder_hidden_states,
367
+ layer_scale=ip_layer_scale,
368
+ )
369
+ if len(attention_outputs) == 2:
370
+ attn_output, context_attn_output = attention_outputs
371
+ elif len(attention_outputs) == 3:
372
+ attn_output, context_attn_output, ip_attn_output = attention_outputs
373
+
374
+ # Process attention outputs for the `hidden_states`.
375
+ attn_output = gate_msa.unsqueeze(1) * attn_output
376
+ hidden_states = hidden_states + attn_output
377
+
378
+ norm_hidden_states = self.norm2(hidden_states)
379
+ norm_hidden_states = (
380
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
381
+ )
382
+
383
+ ff_output = self.ff(norm_hidden_states)
384
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
385
+
386
+ hidden_states = hidden_states + ff_output
387
+
388
+ # Removing this; it was causing an error after adding the ip_adapter
389
+ #if len(attention_outputs) == 3:
390
+ # hidden_states = hidden_states + ip_attn_output
391
+
392
+ # Process attention outputs for the `encoder_hidden_states`.
393
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
394
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
395
+
396
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
397
+ norm_encoder_hidden_states = (
398
+ norm_encoder_hidden_states * (1 + c_scale_mlp[:, None])
399
+ + c_shift_mlp[:, None]
400
+ )
401
+
402
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
403
+ encoder_hidden_states = (
404
+ encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
405
+ )
406
+
407
+ if encoder_hidden_states.dtype == torch.float16:
408
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
409
+
410
+ return encoder_hidden_states, hidden_states
411
+
412
+ class LibreFluxTransformer2DModel(
413
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
414
+ ):
415
+ """
416
+ The Transformer model introduced in Flux.
417
+
418
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
419
+
420
+ Parameters:
421
+ patch_size (`int`): Patch size to turn the input data into small patches.
422
+ in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
423
+ num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
424
+ num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
425
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
426
+ num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
427
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
428
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
429
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
430
+ """
431
+
432
+ _supports_gradient_checkpointing = True
433
+
434
+ @register_to_config
435
+ def __init__(
436
+ self,
437
+ patch_size: int = 1,
438
+ in_channels: int = 64,
439
+ num_layers: int = 19,
440
+ num_single_layers: int = 38,
441
+ attention_head_dim: int = 128,
442
+ num_attention_heads: int = 24,
443
+ joint_attention_dim: int = 4096,
444
+ pooled_projection_dim: int = 768,
445
+ guidance_embeds: bool = False,
446
+ axes_dims_rope: Tuple[int] = (16, 56, 56),
447
+ ):
448
+ super().__init__()
449
+ self.out_channels = in_channels
450
+ self.inner_dim = (
451
+ self.config.num_attention_heads * self.config.attention_head_dim
452
+ )
453
+
454
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
455
+ text_time_guidance_cls = (
456
+ CombinedTimestepGuidanceTextProjEmbeddings ### 3 input forward (timestep, guidance, pooled_projection)
457
+ if guidance_embeds
458
+ else CombinedTimestepTextProjEmbeddings #### 2 input forward (timestep, pooled_projection)
459
+ )
460
+ self.time_text_embed = text_time_guidance_cls(
461
+ embedding_dim=self.inner_dim,
462
+ pooled_projection_dim=self.config.pooled_projection_dim,
463
+ )
464
+
465
+ self.context_embedder = nn.Linear(
466
+ self.config.joint_attention_dim, self.inner_dim
467
+ )
468
+ self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
469
+
470
+ self.transformer_blocks = nn.ModuleList(
471
+ [
472
+ FluxTransformerBlock(
473
+ dim=self.inner_dim,
474
+ num_attention_heads=self.config.num_attention_heads,
475
+ attention_head_dim=self.config.attention_head_dim,
476
+ )
477
+ for i in range(self.config.num_layers)
478
+ ]
479
+ )
480
+
481
+ self.single_transformer_blocks = nn.ModuleList(
482
+ [
483
+ FluxSingleTransformerBlock(
484
+ dim=self.inner_dim,
485
+ num_attention_heads=self.config.num_attention_heads,
486
+ attention_head_dim=self.config.attention_head_dim,
487
+ )
488
+ for i in range(self.config.num_single_layers)
489
+ ]
490
+ )
491
+
492
+ self.norm_out = AdaLayerNormContinuous(
493
+ self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
494
+ )
495
+ self.proj_out = nn.Linear(
496
+ self.inner_dim, patch_size * patch_size * self.out_channels, bias=True
497
+ )
498
+
499
+ self.gradient_checkpointing = False
500
+ # added for users to disable checkpointing every nth step
501
+ self.gradient_checkpointing_interval = None
502
+
503
+ def set_gradient_checkpointing_interval(self, value: int):
504
+ self.gradient_checkpointing_interval = value
505
+
506
+ @property
507
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
508
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
509
+ r"""
510
+ Returns:
511
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
512
+ indexed by its weight name.
513
+ """
514
+ # set recursively
515
+ processors = {}
516
+
517
+ def fn_recursive_add_processors(
518
+ name: str,
519
+ module: torch.nn.Module,
520
+ processors: Dict[str, AttentionProcessor],
521
+ ):
522
+ if hasattr(module, "get_processor"):
523
+ processors[f"{name}.processor"] = module.get_processor()
524
+
525
+ for sub_name, child in module.named_children():
526
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
527
+
528
+ return processors
529
+
530
+ for name, module in self.named_children():
531
+ fn_recursive_add_processors(name, module, processors)
532
+
533
+ return processors
534
+
535
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
536
+ def set_attn_processor(
537
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
538
+ ):
539
+ r"""
540
+ Sets the attention processor to use to compute attention.
541
+
542
+ Parameters:
543
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
544
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
545
+ for **all** `Attention` layers.
546
+
547
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
548
+ processor. This is strongly recommended when setting trainable attention processors.
549
+
550
+ """
551
+ count = len(self.attn_processors.keys())
552
+
553
+ if isinstance(processor, dict) and len(processor) != count:
554
+ raise ValueError(
555
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
556
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
557
+ )
558
+
559
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
560
+ if hasattr(module, "set_processor"):
561
+ if not isinstance(processor, dict):
562
+ module.set_processor(processor)
563
+ else:
564
+ module.set_processor(processor.pop(f"{name}.processor"))
565
+
566
+ for sub_name, child in module.named_children():
567
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
568
+
569
+ for name, module in self.named_children():
570
+ fn_recursive_attn_processor(name, module, processor)
571
+
572
+ def forward(
573
+ self,
574
+ hidden_states: torch.Tensor,
575
+ encoder_hidden_states: torch.Tensor = None,
576
+ pooled_projections: torch.Tensor = None,
577
+ timestep: torch.LongTensor = None,
578
+ img_ids: torch.Tensor = None,
579
+ txt_ids: torch.Tensor = None,
580
+ guidance: torch.Tensor = None,
581
+ joint_attention_kwargs: dict = {},
582
+ controlnet_block_samples=None,
583
+ controlnet_single_block_samples=None,
584
+ return_dict: bool = True,
585
+ attention_mask: Optional[torch.Tensor] = None,
586
+ controlnet_blocks_repeat: bool = False,
587
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
588
+ """
589
+ The [`FluxTransformer2DModel`] forward method.
590
+
591
+ Args:
592
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
593
+ Input `hidden_states`.
594
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
595
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
596
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
597
+ from the embeddings of input conditions.
598
+ timestep ( `torch.LongTensor`):
599
+ Used to indicate denoising step.
600
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
601
+ A list of tensors that if specified are added to the residuals of transformer blocks.
602
+ joint_attention_kwargs (`dict`, *optional*):
603
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
604
+ `self.processor` in
605
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
606
+ return_dict (`bool`, *optional*, defaults to `True`):
607
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
608
+ tuple.
609
+
610
+ Returns:
611
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
612
+ `tuple` where the first element is the sample tensor.
613
+ """
614
+ if joint_attention_kwargs is not None:
615
+ joint_attention_kwargs = joint_attention_kwargs.copy()
616
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
617
+ else:
618
+ lora_scale = 1.0
619
+
620
+ if USE_PEFT_BACKEND:
621
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
622
+ scale_lora_layers(self, lora_scale)
623
+ else:
624
+ if (
625
+ joint_attention_kwargs is not None
626
+ and joint_attention_kwargs.get("scale", None) is not None
627
+ ):
628
+ logger.warning(
629
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
630
+ )
631
+ hidden_states = self.x_embedder(hidden_states)
632
+
633
+ timestep = timestep.to(hidden_states.dtype) * 1000
634
+ if guidance is not None:
635
+ guidance = guidance.to(hidden_states.dtype) * 1000
636
+ else:
637
+ guidance = None
638
+
639
+ #print( self.time_text_embed)
640
+ temb = (
641
+ self.time_text_embed(timestep,pooled_projections)
642
+ # Edit 1 # Charlie NOT NEEDED - UNDONE
643
+ if guidance is None
644
+ else self.time_text_embed(timestep, guidance, pooled_projections)
645
+ )
646
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
647
+
648
+ if txt_ids.ndim == 3:
649
+ txt_ids = txt_ids[0]
650
+ if img_ids.ndim == 3:
651
+ img_ids = img_ids[0]
652
+
653
+ ids = torch.cat((txt_ids, img_ids), dim=0)
654
+
655
+ image_rotary_emb = self.pos_embed(ids)
656
+
657
+
658
+ for index_block, block in enumerate(self.transformer_blocks):
659
+ if (
660
+ self.training
661
+ and self.gradient_checkpointing
662
+ and (
663
+ self.gradient_checkpointing_interval is None
664
+ or index_block % self.gradient_checkpointing_interval == 0
665
+ )
666
+ ):
667
+
668
+ def create_custom_forward(module, return_dict=None):
669
+ def custom_forward(*inputs):
670
+ if return_dict is not None:
671
+ return module(*inputs, return_dict=return_dict)
672
+ else:
673
+ return module(*inputs)
674
+
675
+ return custom_forward
676
+
677
+ ckpt_kwargs: Dict[str, Any] = (
678
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
679
+ )
680
+ encoder_hidden_states, hidden_states = (
681
+ torch.utils.checkpoint.checkpoint(
682
+ create_custom_forward(block),
683
+ hidden_states,
684
+ encoder_hidden_states,
685
+ temb,
686
+ image_rotary_emb,
687
+ attention_mask,
688
+ joint_attention_kwargs, # Add this line
689
+ **ckpt_kwargs,
690
+ )
691
+ )
692
+
693
+ else:
694
+ encoder_hidden_states, hidden_states = block(
695
+ hidden_states=hidden_states,
696
+ encoder_hidden_states=encoder_hidden_states,
697
+ temb=temb,
698
+ image_rotary_emb=image_rotary_emb,
699
+ attention_mask=attention_mask,
700
+ joint_attention_kwargs=joint_attention_kwargs # Add this line
701
+ )
702
+
703
+ # controlnet residual
704
+ if controlnet_block_samples is not None:
705
+ interval_control = len(self.transformer_blocks) / len(
706
+ controlnet_block_samples
707
+ )
708
+ interval_control = int(np.ceil(interval_control))
709
+ # For Xlabs ControlNet.
710
+ if controlnet_blocks_repeat:
711
+ hidden_states = (
712
+ hidden_states
713
+ + controlnet_block_samples[
714
+ index_block % len(controlnet_block_samples)
715
+ ]
716
+ )
717
+ else:
718
+ hidden_states = (
719
+ hidden_states
720
+ + controlnet_block_samples[index_block // interval_control]
721
+ )
722
+
723
+ # Flux places the text tokens in front of the image tokens in the
724
+ # sequence.
725
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
726
+
727
+ for index_block, block in enumerate(self.single_transformer_blocks):
728
+ if (
729
+ self.training
730
+ and self.gradient_checkpointing
731
+ or (
732
+ self.gradient_checkpointing_interval is not None
733
+ and index_block % self.gradient_checkpointing_interval == 0
734
+ )
735
+ ):
736
+
737
+ def create_custom_forward(module, return_dict=None):
738
+ def custom_forward(*inputs):
739
+ if return_dict is not None:
740
+ return module(*inputs, return_dict=return_dict)
741
+ else:
742
+ return module(*inputs)
743
+
744
+ return custom_forward
745
+
746
+ ckpt_kwargs: Dict[str, Any] = (
747
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
748
+ )
749
+ hidden_states = torch.utils.checkpoint.checkpoint(
750
+ create_custom_forward(block),
751
+ hidden_states,
752
+ temb,
753
+ image_rotary_emb,
754
+ attention_mask,
755
+ joint_attention_kwargs, # Add this line
756
+
757
+ **ckpt_kwargs,
758
+ )
759
+
760
+ else:
761
+ hidden_states = block(
762
+ hidden_states=hidden_states,
763
+ temb=temb,
764
+ image_rotary_emb=image_rotary_emb,
765
+ attention_mask=attention_mask,
766
+ joint_attention_kwargs= joint_attention_kwargs, # Add this line
767
+
768
+ )
769
+
770
+ # controlnet residual
771
+ if controlnet_single_block_samples is not None:
772
+ interval_control = len(self.single_transformer_blocks) / len(
773
+ controlnet_single_block_samples
774
+ )
775
+ interval_control = int(np.ceil(interval_control))
776
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
777
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
778
+ + controlnet_single_block_samples[index_block // interval_control]
779
+ )
780
+
781
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
782
+
783
+ hidden_states = self.norm_out(hidden_states, temb)
784
+ output = self.proj_out(hidden_states)
785
+
786
+ if USE_PEFT_BACKEND:
787
+ # remove `lora_scale` from each PEFT layer
788
+ unscale_lora_layers(self, lora_scale)
789
+
790
+ if not return_dict:
791
+ return (output,)
792
+
793
+ return Transformer2DModelOutput(sample=output)
794
+
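Beyond the attention-mask handling, the blocks above read two optional entries, "ip_hidden_states" and "ip_layer_scale", from joint_attention_kwargs and pass them through to the attention processor; this is the hook the repo's IP-Adapter presumably consumes by swapping in its own processor. A minimal call sketch against the forward() signature above (not part of the upload); the tensors are placeholders, and the pipeline that actually prepares the image-prompt embeddings lives elsewhere in this repo, not in this file:
# Sketch only: shapes follow transformer/config.json (in_channels=64,
# joint_attention_dim=4096, pooled_projection_dim=768); timestep is expected
# in [0, 1] because forward() rescales it by 1000.
noise_pred = transformer(
    hidden_states=packed_latents,             # (batch, image_tokens, 64)
    encoder_hidden_states=prompt_embeds,      # (batch, text_tokens, 4096)
    pooled_projections=pooled_prompt_embeds,  # (batch, 768)
    timestep=timestep,
    img_ids=img_ids,
    txt_ids=txt_ids,
    attention_mask=text_attention_mask,       # 1 for real text tokens, 0 for padding
    joint_attention_kwargs={"ip_hidden_states": ip_embeds, "ip_layer_scale": 1.0},
    return_dict=False,
)[0]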
vae/config.json ADDED
@@ -0,0 +1,37 @@
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.30.0.dev0",
4
+ "act_fn": "silu",
5
+ "block_out_channels": [
6
+ 128,
7
+ 256,
8
+ 512,
9
+ 512
10
+ ],
11
+ "down_block_types": [
12
+ "DownEncoderBlock2D",
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D"
16
+ ],
17
+ "force_upcast": true,
18
+ "in_channels": 3,
19
+ "latent_channels": 16,
20
+ "latents_mean": null,
21
+ "latents_std": null,
22
+ "layers_per_block": 2,
23
+ "mid_block_add_attention": true,
24
+ "norm_num_groups": 32,
25
+ "out_channels": 3,
26
+ "sample_size": 1024,
27
+ "scaling_factor": 0.3611,
28
+ "shift_factor": 0.1159,
29
+ "up_block_types": [
30
+ "UpDecoderBlock2D",
31
+ "UpDecoderBlock2D",
32
+ "UpDecoderBlock2D",
33
+ "UpDecoderBlock2D"
34
+ ],
35
+ "use_post_quant_conv": false,
36
+ "use_quant_conv": false
37
+ }
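These values match the stock Flux autoencoder (16 latent channels, scaling_factor 0.3611, shift_factor 0.1159). A minimal decode sketch (not part of the upload), assuming the usual Flux convention of un-scaling and un-shifting latents before decoding; latents stands in for the unpacked output of the transformer:
from diffusers import AutoencoderKL
# Sketch only: "path/to/this/repo" is a placeholder for a local clone of the repo.
vae = AutoencoderKL.from_pretrained("path/to/this/repo", subfolder="vae")
latents = latents / vae.config.scaling_factor + vae.config.shift_factor
image = vae.decode(latents, return_dict=False)[0]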
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f5b59a26851551b67ae1fe58d32e76486e1e812def4696a4bea97f16604d40a3
3
+ size 167666902