Text-to-Image
Diffusers
Safetensors
LibreFluxIPAdapterPipeline
neuralvfx committed
Commit 1600698 · verified · 1 Parent(s): 340515d

Upload folder using huggingface_hub
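
The commit message indicates the folder was pushed with huggingface_hub. A minimal sketch of such an upload is below; the local folder path and target repo id are illustrative assumptions, not values recorded in this commit.

# Hypothetical upload sketch using huggingface_hub's folder upload API.
from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    folder_path="./LibreFluxIPAdapterPipeline",  # assumed local folder
    repo_id="neuralvfx/LibreFlux-IP-Adapter",    # assumed target repo
    repo_type="model",
)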

Files changed (42)
  1. .gitattributes +19 -35
  2. controlnet/config.json +20 -0
  3. controlnet/diffusion_pytorch_model.safetensors +3 -0
  4. controlnet/net.py +1734 -0
  5. examples/control_ip_example.png +3 -0
  6. examples/david.jpg +3 -0
  7. examples/libre_flux_control_image.png +3 -0
  8. examples/merc.jpeg +0 -0
  9. examples/mona.jpg +3 -0
  10. flux_ip_adapter.py +332 -0
  11. image_encoder/config.json +18 -0
  12. image_encoder/model.safetensors +3 -0
  13. image_encoder/preprocessor_config.json +24 -0
  14. image_encoder/special_tokens_map.json +23 -0
  15. image_encoder/spiece.model +3 -0
  16. image_encoder/tokenizer_config.json +34 -0
  17. ip_adapter.pt +3 -0
  18. model_index.json +40 -0
  19. pipeline.py +1227 -0
  20. scheduler/scheduler_config.json +11 -0
  21. text_encoder/config.json +25 -0
  22. text_encoder/model.safetensors +3 -0
  23. text_encoder_2/config.json +32 -0
  24. text_encoder_2/model-00001-of-00002.safetensors +3 -0
  25. text_encoder_2/model-00002-of-00002.safetensors +3 -0
  26. text_encoder_2/model.safetensors.index.json +226 -0
  27. tokenizer/merges.txt +0 -0
  28. tokenizer/special_tokens_map.json +30 -0
  29. tokenizer/tokenizer_config.json +30 -0
  30. tokenizer/vocab.json +0 -0
  31. tokenizer_2/special_tokens_map.json +125 -0
  32. tokenizer_2/spiece.model +3 -0
  33. tokenizer_2/tokenizer.json +0 -0
  34. tokenizer_2/tokenizer_config.json +940 -0
  35. transformer/config.json +19 -0
  36. transformer/diffusion_pytorch_model-00001-of-00003.safetensors +3 -0
  37. transformer/diffusion_pytorch_model-00002-of-00003.safetensors +3 -0
  38. transformer/diffusion_pytorch_model-00003-of-00003.safetensors +3 -0
  39. transformer/diffusion_pytorch_model.safetensors.index.json +0 -0
  40. transformer/trans.py +794 -0
  41. vae/config.json +37 -0
  42. vae/diffusion_pytorch_model.safetensors +3 -0
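
Taken together, these files form a complete Diffusers-format pipeline folder: model_index.json plus transformer, controlnet, vae, text encoder, tokenizer, and image encoder subfolders, with the custom pipeline class defined in pipeline.py. A minimal loading sketch follows; the repo id and the trust_remote_code requirement are assumptions for illustration, and custom_pipeline simply points Diffusers at the pipeline.py in the same repository.

# Hypothetical loading sketch; repo id is assumed.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "neuralvfx/LibreFlux-IP-Adapter",                  # assumed repo id
    custom_pipeline="neuralvfx/LibreFlux-IP-Adapter",  # load pipeline.py from the repo
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,  # may be required for remote pipeline code, depending on diffusers version
)
pipe.to("cuda")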
.gitattributes CHANGED
@@ -1,35 +1,19 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ examples/side_by_side.png filter=lfs diff=lfs merge=lfs -text
+ examples/libre_flux_control_image.png filter=lfs diff=lfs merge=lfs -text
+ examples/side_by_side_a.png filter=lfs diff=lfs merge=lfs -text
+ examples/side_by_side_b.png filter=lfs diff=lfs merge=lfs -text
+ controlnet/diffusion_pytorch_model.safetensors filter=lfs diff=lfs merge=lfs -text
+ examples/control_ip_example.png filter=lfs diff=lfs merge=lfs -text
+ examples/david.jpg filter=lfs diff=lfs merge=lfs -text
+ examples/mona.jpg filter=lfs diff=lfs merge=lfs -text
+ image_encoder/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ image_encoder/spiece.model filter=lfs diff=lfs merge=lfs -text
+ ip_adapter.pt filter=lfs diff=lfs merge=lfs -text
+ text_encoder/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ text_encoder_2/model-00001-of-00002.safetensors filter=lfs diff=lfs merge=lfs -text
+ text_encoder_2/model-00002-of-00002.safetensors filter=lfs diff=lfs merge=lfs -text
+ tokenizer_2/spiece.model filter=lfs diff=lfs merge=lfs -text
+ transformer/diffusion_pytorch_model-00001-of-00003.safetensors filter=lfs diff=lfs merge=lfs -text
+ transformer/diffusion_pytorch_model-00002-of-00003.safetensors filter=lfs diff=lfs merge=lfs -text
+ transformer/diffusion_pytorch_model-00003-of-00003.safetensors filter=lfs diff=lfs merge=lfs -text
+ vae/diffusion_pytorch_model.safetensors filter=lfs diff=lfs merge=lfs -text

controlnet/config.json ADDED
@@ -0,0 +1,20 @@
+ {
+ "_class_name": "LibreFluxControlNetModel",
+ "_diffusers_version": "0.32.0",
+ "attention_head_dim": 128,
+ "axes_dims_rope": [
+ 16,
+ 56,
+ 56
+ ],
+ "conditioning_embedding_channels": null,
+ "guidance_embeds": true,
+ "in_channels": 64,
+ "joint_attention_dim": 4096,
+ "num_attention_heads": 24,
+ "num_layers": 2,
+ "num_mode": null,
+ "num_single_layers": 4,
+ "patch_size": 1,
+ "pooled_projection_dim": 768
+ }
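
The config above describes a compact Flux-style ControlNet: 2 double-stream layers, 4 single-stream layers, guidance embeddings enabled. A loading sketch follows, assuming the LibreFluxControlNetModel class from controlnet/net.py below is importable and that the weights sit in the controlnet/ subfolder of the same (assumed) repo id.

# Hypothetical sketch: loading the ControlNet weights from the uploaded subfolder.
import torch
from net import LibreFluxControlNetModel  # defined in controlnet/net.py in this commit; import path assumed

controlnet = LibreFluxControlNetModel.from_pretrained(
    "neuralvfx/LibreFlux-IP-Adapter",  # assumed repo id
    subfolder="controlnet",
    torch_dtype=torch.bfloat16,
)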
controlnet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:06e84cb264fc8bf98cc6c1ed5e53a606d061c4440c5ba9164f941dfce4f054b6
+ size 2739920936
controlnet/net.py ADDED
@@ -0,0 +1,1734 @@
1
+
2
+ # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ # This was modified from the ControlNet repo.
17
+
18
+
19
+ import inspect
20
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
21
+
22
+ from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
23
+
24
+ import numpy as np
25
+ import torch
26
+ from transformers import (
27
+ CLIPTextModel,
28
+ CLIPTokenizer,
29
+ T5EncoderModel,
30
+ T5TokenizerFast,
31
+ )
32
+
33
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
34
+ from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin
35
+ from diffusers.models.autoencoders import AutoencoderKL
36
+ ### MERGING THESE ###
37
+ # from src.models.transformer import FluxTransformer2DModel
38
+ # from src.models.controlnet_flux import FluxControlNetModel
39
+ #############
40
+
41
+ ##########################################
42
+ ########### ATTENTION MERGE ##############
43
+ ##########################################
44
+
45
+ import torch
46
+ from torch import Tensor, FloatTensor
47
+ from torch.nn import functional as F
48
+ from einops import rearrange
49
+ from diffusers.models.attention_processor import Attention
50
+ from diffusers.models.embeddings import apply_rotary_emb
51
+
52
+
53
+
54
+ def fa3_sdpa(
55
+ q,
56
+ k,
57
+ v,
58
+ ):
59
+ # flash attention 3 sdpa drop-in replacement
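+ # NOTE: flash_attn_func is not imported anywhere in this file; it is presumably provided
+ # by the FlashAttention-3 package (e.g. flash_attn_interface.flash_attn_func). Without
+ # that import, the *_3_0 processors below will raise a NameError when called.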
60
+ q, k, v = [x.permute(0, 2, 1, 3) for x in [q, k, v]]
61
+ out = flash_attn_func(q, k, v)[0]
62
+ return out.permute(0, 2, 1, 3)
63
+
64
+
65
+ class FluxSingleAttnProcessor3_0:
66
+ r"""
67
+ Processor for Flux single-block self-attention; this variant routes the attention call through the fa3_sdpa FlashAttention-3 wrapper defined above (PyTorch 2.0+ is still required).
68
+ """
69
+
70
+ def __init__(self):
71
+ if not hasattr(F, "scaled_dot_product_attention"):
72
+ raise ImportError(
73
+ "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
74
+ )
75
+
76
+ def __call__(
77
+ self,
78
+ attn,
79
+ hidden_states: Tensor,
80
+ encoder_hidden_states: Tensor = None,
81
+ attention_mask: FloatTensor = None,
82
+ image_rotary_emb: Tensor = None,
83
+ ) -> Tensor:
84
+ input_ndim = hidden_states.ndim
85
+
86
+ if input_ndim == 4:
87
+ batch_size, channel, height, width = hidden_states.shape
88
+ hidden_states = hidden_states.view(
89
+ batch_size, channel, height * width
90
+ ).transpose(1, 2)
91
+
92
+ batch_size, _, _ = (
93
+ hidden_states.shape
94
+ if encoder_hidden_states is None
95
+ else encoder_hidden_states.shape
96
+ )
97
+
98
+ query = attn.to_q(hidden_states)
99
+ if encoder_hidden_states is None:
100
+ encoder_hidden_states = hidden_states
101
+
102
+ key = attn.to_k(encoder_hidden_states)
103
+ value = attn.to_v(encoder_hidden_states)
104
+
105
+ inner_dim = key.shape[-1]
106
+ head_dim = inner_dim // attn.heads
107
+
108
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
109
+
110
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
111
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
112
+
113
+ if attn.norm_q is not None:
114
+ query = attn.norm_q(query)
115
+ if attn.norm_k is not None:
116
+ key = attn.norm_k(key)
117
+
118
+ # Apply RoPE if needed
119
+ if image_rotary_emb is not None:
120
+ query = apply_rotary_emb(query, image_rotary_emb)
121
+ key = apply_rotary_emb(key, image_rotary_emb)
122
+
123
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
124
+ # TODO: add support for attn.scale when we move to Torch 2.1
125
+ # hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
126
+ hidden_states = fa3_sdpa(query, key, value)
127
+ hidden_states = rearrange(hidden_states, "B H L D -> B L (H D)")
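+ # NOTE: fa3_sdpa returns (B, H, L, D), so the rearrange above already yields (B, L, H*D);
+ # the transpose/reshape immediately below repeats that flattening and would scramble the
+ # layout if this processor were actually invoked.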
128
+
129
+ hidden_states = hidden_states.transpose(1, 2).reshape(
130
+ batch_size, -1, attn.heads * head_dim
131
+ )
132
+ hidden_states = hidden_states.to(query.dtype)
133
+
134
+ if input_ndim == 4:
135
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
136
+ batch_size, channel, height, width
137
+ )
138
+
139
+ return hidden_states
140
+
141
+
142
+ class FluxAttnProcessor3_0:
143
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
144
+
145
+ def __init__(self):
146
+ if not hasattr(F, "scaled_dot_product_attention"):
147
+ raise ImportError(
148
+ "FluxAttnProcessor3_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
149
+ )
150
+
151
+ def __call__(
152
+ self,
153
+ attn,
154
+ hidden_states: FloatTensor,
155
+ encoder_hidden_states: FloatTensor = None,
156
+ attention_mask: FloatTensor = None,
157
+ image_rotary_emb: Tensor = None,
158
+ ) -> FloatTensor:
159
+ input_ndim = hidden_states.ndim
160
+ if input_ndim == 4:
161
+ batch_size, channel, height, width = hidden_states.shape
162
+ hidden_states = hidden_states.view(
163
+ batch_size, channel, height * width
164
+ ).transpose(1, 2)
165
+ context_input_ndim = encoder_hidden_states.ndim
166
+ if context_input_ndim == 4:
167
+ batch_size, channel, height, width = encoder_hidden_states.shape
168
+ encoder_hidden_states = encoder_hidden_states.view(
169
+ batch_size, channel, height * width
170
+ ).transpose(1, 2)
171
+
172
+ batch_size = encoder_hidden_states.shape[0]
173
+
174
+ # `sample` projections.
175
+ query = attn.to_q(hidden_states)
176
+ key = attn.to_k(hidden_states)
177
+ value = attn.to_v(hidden_states)
178
+
179
+ inner_dim = key.shape[-1]
180
+ head_dim = inner_dim // attn.heads
181
+
182
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
183
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
184
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
185
+
186
+ if attn.norm_q is not None:
187
+ query = attn.norm_q(query)
188
+ if attn.norm_k is not None:
189
+ key = attn.norm_k(key)
190
+
191
+ # `context` projections.
192
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
193
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
194
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
195
+
196
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
197
+ batch_size, -1, attn.heads, head_dim
198
+ ).transpose(1, 2)
199
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
200
+ batch_size, -1, attn.heads, head_dim
201
+ ).transpose(1, 2)
202
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
203
+ batch_size, -1, attn.heads, head_dim
204
+ ).transpose(1, 2)
205
+
206
+ if attn.norm_added_q is not None:
207
+ encoder_hidden_states_query_proj = attn.norm_added_q(
208
+ encoder_hidden_states_query_proj
209
+ )
210
+ if attn.norm_added_k is not None:
211
+ encoder_hidden_states_key_proj = attn.norm_added_k(
212
+ encoder_hidden_states_key_proj
213
+ )
214
+
215
+ # attention
216
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
217
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
218
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
219
+
220
+ if image_rotary_emb is not None:
221
+
222
+ query = apply_rotary_emb(query, image_rotary_emb)
223
+ key = apply_rotary_emb(key, image_rotary_emb)
224
+
225
+ # hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
226
+ hidden_states = fa3_sdpa(query, key, value)
227
+ hidden_states = rearrange(hidden_states, "B H L D -> B L (H D)")
228
+
229
+ hidden_states = hidden_states.transpose(1, 2).reshape(
230
+ batch_size, -1, attn.heads * head_dim
231
+ )
232
+ hidden_states = hidden_states.to(query.dtype)
233
+
234
+ encoder_hidden_states, hidden_states = (
235
+ hidden_states[:, : encoder_hidden_states.shape[1]],
236
+ hidden_states[:, encoder_hidden_states.shape[1] :],
237
+ )
238
+
239
+ # linear proj
240
+ hidden_states = attn.to_out[0](hidden_states)
241
+ # dropout
242
+ hidden_states = attn.to_out[1](hidden_states)
243
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
244
+
245
+ if input_ndim == 4:
246
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
247
+ batch_size, channel, height, width
248
+ )
249
+ if context_input_ndim == 4:
250
+ encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(
251
+ batch_size, channel, height, width
252
+ )
253
+
254
+ return hidden_states, encoder_hidden_states
255
+
256
+
257
+
258
+ class FluxFusedSDPAProcessor:
259
+ """
260
+ Fused QKV processor using PyTorch's scaled_dot_product_attention.
261
+ Uses fused projections but splits for attention computation.
262
+ """
263
+
264
+ def __init__(self):
265
+ if not hasattr(F, "scaled_dot_product_attention"):
266
+ raise ImportError(
267
+ "FluxFusedSDPAProcessor requires PyTorch 2.0+ for scaled_dot_product_attention"
268
+ )
269
+
270
+ def __call__(
271
+ self,
272
+ attn,
273
+ hidden_states: FloatTensor,
274
+ encoder_hidden_states: FloatTensor = None,
275
+ attention_mask: FloatTensor = None,
276
+ image_rotary_emb: Tensor = None,
277
+ ) -> FloatTensor:
278
+ input_ndim = hidden_states.ndim
279
+ if input_ndim == 4:
280
+ batch_size, channel, height, width = hidden_states.shape
281
+ hidden_states = hidden_states.view(
282
+ batch_size, channel, height * width
283
+ ).transpose(1, 2)
284
+
285
+ context_input_ndim = (
286
+ encoder_hidden_states.ndim if encoder_hidden_states is not None else None
287
+ )
288
+ if context_input_ndim == 4:
289
+ batch_size, channel, height, width = encoder_hidden_states.shape
290
+ encoder_hidden_states = encoder_hidden_states.view(
291
+ batch_size, channel, height * width
292
+ ).transpose(1, 2)
293
+
294
+ batch_size = (
295
+ encoder_hidden_states.shape[0]
296
+ if encoder_hidden_states is not None
297
+ else hidden_states.shape[0]
298
+ )
299
+
300
+ # Single attention case (no encoder states)
301
+ if encoder_hidden_states is None:
302
+ # Use fused QKV projection
303
+ qkv = attn.to_qkv(hidden_states) # (batch, seq_len, 3 * inner_dim)
304
+ inner_dim = qkv.shape[-1] // 3
305
+ head_dim = inner_dim // attn.heads
306
+ seq_len = hidden_states.shape[1]
307
+
308
+ # Split and reshape
309
+ qkv = qkv.view(batch_size, seq_len, 3, attn.heads, head_dim)
310
+ query, key, value = qkv.unbind(
311
+ dim=2
312
+ ) # Each is (batch, seq_len, heads, head_dim)
313
+
314
+ # Transpose to (batch, heads, seq_len, head_dim)
315
+ query = query.transpose(1, 2)
316
+ key = key.transpose(1, 2)
317
+ value = value.transpose(1, 2)
318
+
319
+ # Apply norms if needed
320
+ if attn.norm_q is not None:
321
+ query = attn.norm_q(query)
322
+ if attn.norm_k is not None:
323
+ key = attn.norm_k(key)
324
+
325
+ # Apply RoPE if needed
326
+ if image_rotary_emb is not None:
327
+ query = apply_rotary_emb(query, image_rotary_emb)
328
+ key = apply_rotary_emb(key, image_rotary_emb)
329
+
330
+ # SDPA
331
+ hidden_states = F.scaled_dot_product_attention(
332
+ query,
333
+ key,
334
+ value,
335
+ attn_mask=attention_mask,
336
+ dropout_p=0.0,
337
+ is_causal=False,
338
+ )
339
+
340
+ # Reshape back
341
+ hidden_states = hidden_states.transpose(1, 2).reshape(
342
+ batch_size, -1, attn.heads * head_dim
343
+ )
344
+ hidden_states = hidden_states.to(query.dtype)
345
+
346
+ if input_ndim == 4:
347
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
348
+ batch_size, channel, height, width
349
+ )
350
+
351
+ return hidden_states
352
+
353
+ # Joint attention case (with encoder states)
354
+ else:
355
+ # Process self-attention QKV
356
+ qkv = attn.to_qkv(hidden_states)
357
+ inner_dim = qkv.shape[-1] // 3
358
+ head_dim = inner_dim // attn.heads
359
+ seq_len = hidden_states.shape[1]
360
+
361
+ qkv = qkv.view(batch_size, seq_len, 3, attn.heads, head_dim)
362
+ query, key, value = qkv.unbind(dim=2)
363
+
364
+ # Transpose to (batch, heads, seq_len, head_dim)
365
+ query = query.transpose(1, 2)
366
+ key = key.transpose(1, 2)
367
+ value = value.transpose(1, 2)
368
+
369
+ # Apply norms if needed
370
+ if attn.norm_q is not None:
371
+ query = attn.norm_q(query)
372
+ if attn.norm_k is not None:
373
+ key = attn.norm_k(key)
374
+
375
+ # Process encoder QKV
376
+ encoder_seq_len = encoder_hidden_states.shape[1]
377
+ encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
378
+ encoder_qkv = encoder_qkv.view(
379
+ batch_size, encoder_seq_len, 3, attn.heads, head_dim
380
+ )
381
+ encoder_query, encoder_key, encoder_value = encoder_qkv.unbind(dim=2)
382
+
383
+ # Transpose to (batch, heads, seq_len, head_dim)
384
+ encoder_query = encoder_query.transpose(1, 2)
385
+ encoder_key = encoder_key.transpose(1, 2)
386
+ encoder_value = encoder_value.transpose(1, 2)
387
+
388
+ # Apply encoder norms if needed
389
+ if attn.norm_added_q is not None:
390
+ encoder_query = attn.norm_added_q(encoder_query)
391
+ if attn.norm_added_k is not None:
392
+ encoder_key = attn.norm_added_k(encoder_key)
393
+
394
+ # Concatenate encoder and self-attention
395
+ query = torch.cat([encoder_query, query], dim=2)
396
+ key = torch.cat([encoder_key, key], dim=2)
397
+ value = torch.cat([encoder_value, value], dim=2)
398
+
399
+ # Apply RoPE if needed
400
+ if image_rotary_emb is not None:
401
+ query = apply_rotary_emb(query, image_rotary_emb)
402
+ key = apply_rotary_emb(key, image_rotary_emb)
403
+
404
+ # SDPA
405
+ hidden_states = F.scaled_dot_product_attention(
406
+ query,
407
+ key,
408
+ value,
409
+ attn_mask=attention_mask,
410
+ dropout_p=0.0,
411
+ is_causal=False,
412
+ )
413
+
414
+ # Reshape: (batch, heads, seq_len, head_dim) -> (batch, seq_len, heads * head_dim)
415
+ hidden_states = hidden_states.transpose(1, 2).reshape(
416
+ batch_size, -1, attn.heads * head_dim
417
+ )
418
+ hidden_states = hidden_states.to(query.dtype)
419
+
420
+ # Split encoder and self outputs
421
+ encoder_hidden_states = hidden_states[:, :encoder_seq_len]
422
+ hidden_states = hidden_states[:, encoder_seq_len:]
423
+
424
+ # Output projections
425
+ hidden_states = attn.to_out[0](hidden_states)
426
+ hidden_states = attn.to_out[1](hidden_states) # dropout
427
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
428
+
429
+ # Reshape if needed
430
+ if input_ndim == 4:
431
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
432
+ batch_size, channel, height, width
433
+ )
434
+ if context_input_ndim == 4:
435
+ encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(
436
+ batch_size, channel, height, width
437
+ )
438
+
439
+ return hidden_states, encoder_hidden_states
440
+
441
+
442
+ class FluxSingleFusedSDPAProcessor:
443
+ """
444
+ Fused QKV processor for single attention (no encoder states).
445
+ Simpler version for self-attention only blocks.
446
+ """
447
+
448
+ def __init__(self):
449
+ if not hasattr(F, "scaled_dot_product_attention"):
450
+ raise ImportError(
451
+ "FluxSingleFusedSDPAProcessor requires PyTorch 2.0+ for scaled_dot_product_attention"
452
+ )
453
+
454
+ def __call__(
455
+ self,
456
+ attn,
457
+ hidden_states: Tensor,
458
+ encoder_hidden_states: Tensor = None,
459
+ attention_mask: FloatTensor = None,
460
+ image_rotary_emb: Tensor = None,
461
+ ) -> Tensor:
462
+ input_ndim = hidden_states.ndim
463
+ if input_ndim == 4:
464
+ batch_size, channel, height, width = hidden_states.shape
465
+ hidden_states = hidden_states.view(
466
+ batch_size, channel, height * width
467
+ ).transpose(1, 2)
468
+
469
+ batch_size, seq_len, _ = hidden_states.shape
470
+
471
+ # Use fused QKV projection
472
+ qkv = attn.to_qkv(hidden_states) # (batch, seq_len, 3 * inner_dim)
473
+ inner_dim = qkv.shape[-1] // 3
474
+ head_dim = inner_dim // attn.heads
475
+
476
+ # Split and reshape in one go
477
+ qkv = qkv.view(batch_size, seq_len, 3, attn.heads, head_dim)
478
+ qkv = qkv.permute(2, 0, 3, 1, 4) # (3, B, H, L, D) – still strided
479
+ query, key, value = [
480
+ t.contiguous() for t in qkv.unbind(0) # make each view dense
481
+ ]
482
+ # Now each is (batch, heads, seq_len, head_dim)
483
+
484
+ # Apply norms if needed
485
+ if attn.norm_q is not None:
486
+ query = attn.norm_q(query)
487
+ if attn.norm_k is not None:
488
+ key = attn.norm_k(key)
489
+
490
+ # Apply RoPE if needed
491
+ if image_rotary_emb is not None:
492
+ query = apply_rotary_emb(query, image_rotary_emb)
493
+ key = apply_rotary_emb(key, image_rotary_emb)
494
+
495
+ # SDPA
496
+ hidden_states = F.scaled_dot_product_attention(
497
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
498
+ )
499
+
500
+ # Reshape back
501
+ hidden_states = rearrange(hidden_states, "B H L D -> B L (H D)")
502
+ hidden_states = hidden_states.to(query.dtype)
503
+
504
+ if input_ndim == 4:
505
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
506
+ batch_size, channel, height, width
507
+ )
508
+
509
+ return hidden_states
510
+
511
+ #################################
512
+ ##### TRANSFORMER MERGE #########
513
+ #################################
514
+
515
+ from typing import Any, Dict, List, Optional, Tuple, Union
516
+
517
+ import torch
518
+ import torch.nn as nn
519
+ import torch.nn.functional as F
520
+ import numpy as np
521
+
522
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
523
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
524
+ from diffusers.models.attention import FeedForward
525
+ from diffusers.models.attention_processor import (
526
+ Attention,
527
+ AttentionProcessor,
528
+ )
529
+ from diffusers.models.modeling_utils import ModelMixin
530
+ from diffusers.models.normalization import (
531
+ AdaLayerNormContinuous,
532
+ AdaLayerNormZero,
533
+ AdaLayerNormZeroSingle,
534
+ )
535
+ from diffusers.utils import (
536
+ USE_PEFT_BACKEND,
537
+ is_torch_version,
538
+ logging,
539
+ scale_lora_layers,
540
+ unscale_lora_layers,
541
+ )
542
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
543
+ from diffusers.models.embeddings import (
544
+ CombinedTimestepGuidanceTextProjEmbeddings,
545
+ CombinedTimestepTextProjEmbeddings,
546
+ FluxPosEmbed,
547
+ )
548
+
549
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
550
+ from diffusers import FluxTransformer2DModel as OriginalFluxTransformer2DModel
551
+
552
+
553
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
554
+
555
+ is_flash_attn_available = False
556
+
557
+
558
+
559
+ class FluxAttnProcessor2_0:
560
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
561
+
562
+ def __init__(self):
563
+ if not hasattr(F, "scaled_dot_product_attention"):
564
+ raise ImportError(
565
+ "FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
566
+ )
567
+
568
+ def __call__(
569
+ self,
570
+ attn: Attention,
571
+ hidden_states: torch.FloatTensor,
572
+ encoder_hidden_states: torch.FloatTensor = None,
573
+ attention_mask: Optional[torch.FloatTensor] = None,
574
+ image_rotary_emb: Optional[torch.Tensor] = None,
575
+ ) -> torch.FloatTensor:
576
+ batch_size, _, _ = (
577
+ hidden_states.shape
578
+ if encoder_hidden_states is None
579
+ else encoder_hidden_states.shape
580
+ )
581
+
582
+ # `sample` projections.
583
+ query = attn.to_q(hidden_states)
584
+ key = attn.to_k(hidden_states)
585
+ value = attn.to_v(hidden_states)
586
+
587
+ inner_dim = key.shape[-1]
588
+ head_dim = inner_dim // attn.heads
589
+
590
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
591
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
592
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
593
+
594
+ if attn.norm_q is not None:
595
+ query = attn.norm_q(query)
596
+ if attn.norm_k is not None:
597
+ key = attn.norm_k(key)
598
+
599
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
600
+ if encoder_hidden_states is not None:
601
+ # `context` projections.
602
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
603
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
604
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
605
+
606
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
607
+ batch_size, -1, attn.heads, head_dim
608
+ ).transpose(1, 2)
609
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
610
+ batch_size, -1, attn.heads, head_dim
611
+ ).transpose(1, 2)
612
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
613
+ batch_size, -1, attn.heads, head_dim
614
+ ).transpose(1, 2)
615
+
616
+ if attn.norm_added_q is not None:
617
+ encoder_hidden_states_query_proj = attn.norm_added_q(
618
+ encoder_hidden_states_query_proj
619
+ )
620
+ if attn.norm_added_k is not None:
621
+ encoder_hidden_states_key_proj = attn.norm_added_k(
622
+ encoder_hidden_states_key_proj
623
+ )
624
+
625
+ # attention
626
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
627
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
628
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
629
+
630
+ if image_rotary_emb is not None:
631
+ from diffusers.models.embeddings import apply_rotary_emb
632
+
633
+ query = apply_rotary_emb(query, image_rotary_emb)
634
+ key = apply_rotary_emb(key, image_rotary_emb)
635
+
636
+ if attention_mask is not None:
637
+ #print ('Attention Used')
638
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
639
+ attention_mask = (attention_mask > 0).bool()
640
+ # Edit 17 - match attention-mask dtype to the query dtype
641
+ attention_mask = attention_mask.to(
642
+ device=hidden_states.device, dtype=query.dtype
643
+ )
644
+
645
+ hidden_states = F.scaled_dot_product_attention(
646
+ query,
647
+ key,
648
+ value,
649
+ dropout_p=0.0,
650
+ is_causal=False,
651
+ attn_mask=attention_mask,
652
+ )
653
+ hidden_states = hidden_states.transpose(1, 2).reshape(
654
+ batch_size, -1, attn.heads * head_dim
655
+ )
656
+ hidden_states = hidden_states.to(query.dtype)
657
+
658
+ if encoder_hidden_states is not None:
659
+ encoder_hidden_states, hidden_states = (
660
+ hidden_states[:, : encoder_hidden_states.shape[1]],
661
+ hidden_states[:, encoder_hidden_states.shape[1] :],
662
+ )
663
+
664
+ # linear proj
665
+ hidden_states = attn.to_out[0](hidden_states)
666
+ # dropout
667
+ hidden_states = attn.to_out[1](hidden_states)
668
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
669
+
670
+ return hidden_states, encoder_hidden_states
671
+ return hidden_states
672
+
673
+
674
+ def expand_flux_attention_mask(
675
+ hidden_states: torch.Tensor,
676
+ attn_mask: torch.Tensor,
677
+ ) -> torch.Tensor:
678
+ """
679
+ Expand a text-token attention mask to the full joint sequence length; the extra positions (the image tokens) are filled with ones so they are always attended.
680
+ """
681
+ bsz = attn_mask.shape[0]
682
+ assert bsz == hidden_states.shape[0]
683
+ residual_seq_len = hidden_states.shape[1]
684
+ mask_seq_len = attn_mask.shape[1]
685
+
686
+ expanded_mask = torch.ones(bsz, residual_seq_len)
687
+ expanded_mask[:, :mask_seq_len] = attn_mask
688
+
689
+ return expanded_mask
690
+
691
+
692
+ @maybe_allow_in_graph
693
+ class FluxSingleTransformerBlock(nn.Module):
694
+ r"""
695
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
696
+
697
+ Reference: https://arxiv.org/abs/2403.03206
698
+
699
+ Parameters:
700
+ dim (`int`): The number of channels in the input and output.
701
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
702
+ attention_head_dim (`int`): The number of channels in each head.
703
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
704
+ processing of `context` conditions.
705
+ """
706
+
707
+ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
708
+ super().__init__()
709
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
710
+
711
+ self.norm = AdaLayerNormZeroSingle(dim)
712
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
713
+ self.act_mlp = nn.GELU(approximate="tanh")
714
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
715
+
716
+ processor = FluxAttnProcessor2_0()
717
+ self.attn = Attention(
718
+ query_dim=dim,
719
+ cross_attention_dim=None,
720
+ dim_head=attention_head_dim,
721
+ heads=num_attention_heads,
722
+ out_dim=dim,
723
+ bias=True,
724
+ processor=processor,
725
+ qk_norm="rms_norm",
726
+ eps=1e-6,
727
+ pre_only=True,
728
+ )
729
+
730
+ def forward(
731
+ self,
732
+ hidden_states: torch.FloatTensor,
733
+ temb: torch.FloatTensor,
734
+ image_rotary_emb=None,
735
+ attention_mask: Optional[torch.Tensor] = None,
736
+ ):
737
+ residual = hidden_states
738
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
739
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
740
+
741
+ if attention_mask is not None:
742
+ attention_mask = expand_flux_attention_mask(
743
+ hidden_states,
744
+ attention_mask,
745
+ )
746
+
747
+ attn_output = self.attn(
748
+ hidden_states=norm_hidden_states,
749
+ image_rotary_emb=image_rotary_emb,
750
+ attention_mask=attention_mask,
751
+ )
752
+
753
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
754
+ gate = gate.unsqueeze(1)
755
+ hidden_states = gate * self.proj_out(hidden_states)
756
+ hidden_states = residual + hidden_states
757
+
758
+ if hidden_states.dtype == torch.float16:
759
+ hidden_states = hidden_states.clip(-65504, 65504)
760
+
761
+ return hidden_states
762
+
763
+
764
+ @maybe_allow_in_graph
765
+ class FluxTransformerBlock(nn.Module):
766
+ r"""
767
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
768
+
769
+ Reference: https://arxiv.org/abs/2403.03206
770
+
771
+ Parameters:
772
+ dim (`int`): The number of channels in the input and output.
773
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
774
+ attention_head_dim (`int`): The number of channels in each head.
775
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
776
+ processing of `context` conditions.
777
+ """
778
+
779
+ def __init__(
780
+ self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6
781
+ ):
782
+ super().__init__()
783
+
784
+ self.norm1 = AdaLayerNormZero(dim)
785
+
786
+ self.norm1_context = AdaLayerNormZero(dim)
787
+
788
+ if hasattr(F, "scaled_dot_product_attention"):
789
+ processor = FluxAttnProcessor2_0()
790
+ else:
791
+ raise ValueError(
792
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
793
+ )
794
+ self.attn = Attention(
795
+ query_dim=dim,
796
+ cross_attention_dim=None,
797
+ added_kv_proj_dim=dim,
798
+ dim_head=attention_head_dim,
799
+ heads=num_attention_heads,
800
+ out_dim=dim,
801
+ context_pre_only=False,
802
+ bias=True,
803
+ processor=processor,
804
+ qk_norm=qk_norm,
805
+ eps=eps,
806
+ )
807
+
808
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
809
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
810
+
811
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
812
+ self.ff_context = FeedForward(
813
+ dim=dim, dim_out=dim, activation_fn="gelu-approximate"
814
+ )
815
+
816
+ # let chunk size default to None
817
+ self._chunk_size = None
818
+ self._chunk_dim = 0
819
+
820
+ def forward(
821
+ self,
822
+ hidden_states: torch.FloatTensor,
823
+ encoder_hidden_states: torch.FloatTensor,
824
+ temb: torch.FloatTensor,
825
+ image_rotary_emb=None,
826
+ attention_mask: Optional[torch.Tensor] = None,
827
+ ):
828
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
829
+ hidden_states, emb=temb
830
+ )
831
+
832
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
833
+ self.norm1_context(encoder_hidden_states, emb=temb)
834
+ )
835
+
836
+ if attention_mask is not None:
837
+ attention_mask = expand_flux_attention_mask(
838
+ torch.cat([encoder_hidden_states, hidden_states], dim=1),
839
+ attention_mask,
840
+ )
841
+
842
+ # Attention.
843
+ attention_outputs = self.attn(
844
+ hidden_states=norm_hidden_states,
845
+ encoder_hidden_states=norm_encoder_hidden_states,
846
+ image_rotary_emb=image_rotary_emb,
847
+ attention_mask=attention_mask,
848
+ )
849
+ if len(attention_outputs) == 2:
850
+ attn_output, context_attn_output = attention_outputs
851
+ elif len(attention_outputs) == 3:
852
+ attn_output, context_attn_output, ip_attn_output = attention_outputs
853
+
854
+ # Process attention outputs for the `hidden_states`.
855
+ attn_output = gate_msa.unsqueeze(1) * attn_output
856
+ hidden_states = hidden_states + attn_output
857
+
858
+ norm_hidden_states = self.norm2(hidden_states)
859
+ norm_hidden_states = (
860
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
861
+ )
862
+
863
+ ff_output = self.ff(norm_hidden_states)
864
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
865
+
866
+ hidden_states = hidden_states + ff_output
867
+ if len(attention_outputs) == 3:
868
+ hidden_states = hidden_states + ip_attn_output
869
+
870
+ # Process attention outputs for the `encoder_hidden_states`.
871
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
872
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
873
+
874
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
875
+ norm_encoder_hidden_states = (
876
+ norm_encoder_hidden_states * (1 + c_scale_mlp[:, None])
877
+ + c_shift_mlp[:, None]
878
+ )
879
+
880
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
881
+ encoder_hidden_states = (
882
+ encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
883
+ )
884
+
885
+ if encoder_hidden_states.dtype == torch.float16:
886
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
887
+
888
+ return encoder_hidden_states, hidden_states
889
+
890
+
891
+ class LibreFluxTransformer2DModel(
892
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
893
+ ):
894
+ """
895
+ The Transformer model introduced in Flux.
896
+
897
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
898
+
899
+ Parameters:
900
+ patch_size (`int`): Patch size to turn the input data into small patches.
901
+ in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
902
+ num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
903
+ num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
904
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
905
+ num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
906
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
907
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
908
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
909
+ """
910
+
911
+ _supports_gradient_checkpointing = True
912
+
913
+ @register_to_config
914
+ def __init__(
915
+ self,
916
+ patch_size: int = 1,
917
+ in_channels: int = 64,
918
+ num_layers: int = 19,
919
+ num_single_layers: int = 38,
920
+ attention_head_dim: int = 128,
921
+ num_attention_heads: int = 24,
922
+ joint_attention_dim: int = 4096,
923
+ pooled_projection_dim: int = 768,
924
+ guidance_embeds: bool = False,
925
+ axes_dims_rope: Tuple[int] = (16, 56, 56),
926
+ ):
927
+ super().__init__()
928
+ self.out_channels = in_channels
929
+ self.inner_dim = (
930
+ self.config.num_attention_heads * self.config.attention_head_dim
931
+ )
932
+
933
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
934
+ text_time_guidance_cls = (
935
+ CombinedTimestepGuidanceTextProjEmbeddings ### 3 input forward (timestep, guidance, pooled_projection)
936
+ if guidance_embeds
937
+ else CombinedTimestepTextProjEmbeddings #### 2 input forward (timestep, pooled_projection)
938
+ )
939
+ self.time_text_embed = text_time_guidance_cls(
940
+ embedding_dim=self.inner_dim,
941
+ pooled_projection_dim=self.config.pooled_projection_dim,
942
+ )
943
+
944
+ self.context_embedder = nn.Linear(
945
+ self.config.joint_attention_dim, self.inner_dim
946
+ )
947
+ self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
948
+
949
+ self.transformer_blocks = nn.ModuleList(
950
+ [
951
+ FluxTransformerBlock(
952
+ dim=self.inner_dim,
953
+ num_attention_heads=self.config.num_attention_heads,
954
+ attention_head_dim=self.config.attention_head_dim,
955
+ )
956
+ for i in range(self.config.num_layers)
957
+ ]
958
+ )
959
+
960
+ self.single_transformer_blocks = nn.ModuleList(
961
+ [
962
+ FluxSingleTransformerBlock(
963
+ dim=self.inner_dim,
964
+ num_attention_heads=self.config.num_attention_heads,
965
+ attention_head_dim=self.config.attention_head_dim,
966
+ )
967
+ for i in range(self.config.num_single_layers)
968
+ ]
969
+ )
970
+
971
+ self.norm_out = AdaLayerNormContinuous(
972
+ self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
973
+ )
974
+ self.proj_out = nn.Linear(
975
+ self.inner_dim, patch_size * patch_size * self.out_channels, bias=True
976
+ )
977
+
978
+ self.gradient_checkpointing = False
979
+ # added for users to disable checkpointing every nth step
980
+ self.gradient_checkpointing_interval = None
981
+
982
+ def set_gradient_checkpointing_interval(self, value: int):
983
+ self.gradient_checkpointing_interval = value
984
+
985
+ @property
986
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
987
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
988
+ r"""
989
+ Returns:
990
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
991
+ indexed by its weight name.
992
+ """
993
+ # set recursively
994
+ processors = {}
995
+
996
+ def fn_recursive_add_processors(
997
+ name: str,
998
+ module: torch.nn.Module,
999
+ processors: Dict[str, AttentionProcessor],
1000
+ ):
1001
+ if hasattr(module, "get_processor"):
1002
+ processors[f"{name}.processor"] = module.get_processor()
1003
+
1004
+ for sub_name, child in module.named_children():
1005
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
1006
+
1007
+ return processors
1008
+
1009
+ for name, module in self.named_children():
1010
+ fn_recursive_add_processors(name, module, processors)
1011
+
1012
+ return processors
1013
+
1014
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
1015
+ def set_attn_processor(
1016
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
1017
+ ):
1018
+ r"""
1019
+ Sets the attention processor to use to compute attention.
1020
+
1021
+ Parameters:
1022
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
1023
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
1024
+ for **all** `Attention` layers.
1025
+
1026
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
1027
+ processor. This is strongly recommended when setting trainable attention processors.
1028
+
1029
+ """
1030
+ count = len(self.attn_processors.keys())
1031
+
1032
+ if isinstance(processor, dict) and len(processor) != count:
1033
+ raise ValueError(
1034
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
1035
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
1036
+ )
1037
+
1038
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
1039
+ if hasattr(module, "set_processor"):
1040
+ if not isinstance(processor, dict):
1041
+ module.set_processor(processor)
1042
+ else:
1043
+ module.set_processor(processor.pop(f"{name}.processor"))
1044
+
1045
+ for sub_name, child in module.named_children():
1046
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
1047
+
1048
+ for name, module in self.named_children():
1049
+ fn_recursive_attn_processor(name, module, processor)
1050
+
1051
+ def forward(
1052
+ self,
1053
+ hidden_states: torch.Tensor,
1054
+ encoder_hidden_states: torch.Tensor = None,
1055
+ pooled_projections: torch.Tensor = None,
1056
+ timestep: torch.LongTensor = None,
1057
+ img_ids: torch.Tensor = None,
1058
+ txt_ids: torch.Tensor = None,
1059
+ guidance: torch.Tensor = None,
1060
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
1061
+ controlnet_block_samples=None,
1062
+ controlnet_single_block_samples=None,
1063
+ return_dict: bool = True,
1064
+ attention_mask: Optional[torch.Tensor] = None,
1065
+ controlnet_blocks_repeat: bool = False,
1066
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
1067
+ """
1068
+ The [`FluxTransformer2DModel`] forward method.
1069
+
1070
+ Args:
1071
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
1072
+ Input `hidden_states`.
1073
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
1074
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
1075
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
1076
+ from the embeddings of input conditions.
1077
+ timestep ( `torch.LongTensor`):
1078
+ Used to indicate denoising step.
1079
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
1080
+ A list of tensors that if specified are added to the residuals of transformer blocks.
1081
+ joint_attention_kwargs (`dict`, *optional*):
1082
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1083
+ `self.processor` in
1084
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1085
+ return_dict (`bool`, *optional*, defaults to `True`):
1086
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
1087
+ tuple.
1088
+
1089
+ Returns:
1090
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
1091
+ `tuple` where the first element is the sample tensor.
1092
+ """
1093
+ if joint_attention_kwargs is not None:
1094
+ joint_attention_kwargs = joint_attention_kwargs.copy()
1095
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
1096
+ else:
1097
+ lora_scale = 1.0
1098
+
1099
+ if USE_PEFT_BACKEND:
1100
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1101
+ scale_lora_layers(self, lora_scale)
1102
+ else:
1103
+ if (
1104
+ joint_attention_kwargs is not None
1105
+ and joint_attention_kwargs.get("scale", None) is not None
1106
+ ):
1107
+ logger.warning(
1108
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
1109
+ )
1110
+ hidden_states = self.x_embedder(hidden_states)
1111
+
1112
+ timestep = timestep.to(hidden_states.dtype) * 1000
1113
+ if guidance is not None:
1114
+ guidance = guidance.to(hidden_states.dtype) * 1000
1115
+ else:
1116
+ guidance = None
1117
+
1118
+ #print( self.time_text_embed)
1119
+ temb = (
1120
+ self.time_text_embed(timestep,pooled_projections)
1121
+ # Edit 1 # Charlie NOT NEEDED - UNDONE
1122
+ if guidance is None
1123
+ else self.time_text_embed(timestep, guidance, pooled_projections)
1124
+ )
1125
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
1126
+
1127
+ if txt_ids.ndim == 3:
1128
+ txt_ids = txt_ids[0]
1129
+ if img_ids.ndim == 3:
1130
+ img_ids = img_ids[0]
1131
+
1132
+ ids = torch.cat((txt_ids, img_ids), dim=0)
1133
+
1134
+ image_rotary_emb = self.pos_embed(ids)
1135
+
1136
+ # IP adapter
1137
+ if (
1138
+ joint_attention_kwargs is not None
1139
+ and "ip_adapter_image_embeds" in joint_attention_kwargs
1140
+ ):
1141
+ ip_adapter_image_embeds = joint_attention_kwargs.pop(
1142
+ "ip_adapter_image_embeds"
1143
+ )
1144
+ ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
1145
+ joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
1146
+
1147
+ for index_block, block in enumerate(self.transformer_blocks):
1148
+ if (
1149
+ self.training
1150
+ and self.gradient_checkpointing
1151
+ and (
1152
+ self.gradient_checkpointing_interval is None
1153
+ or index_block % self.gradient_checkpointing_interval == 0
1154
+ )
1155
+ ):
1156
+
1157
+ def create_custom_forward(module, return_dict=None):
1158
+ def custom_forward(*inputs):
1159
+ if return_dict is not None:
1160
+ return module(*inputs, return_dict=return_dict)
1161
+ else:
1162
+ return module(*inputs)
1163
+
1164
+ return custom_forward
1165
+
1166
+ ckpt_kwargs: Dict[str, Any] = (
1167
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1168
+ )
1169
+ encoder_hidden_states, hidden_states = (
1170
+ torch.utils.checkpoint.checkpoint(
1171
+ create_custom_forward(block),
1172
+ hidden_states,
1173
+ encoder_hidden_states,
1174
+ temb,
1175
+ image_rotary_emb,
1176
+ attention_mask,
1177
+ **ckpt_kwargs,
1178
+ )
1179
+ )
1180
+
1181
+ else:
1182
+ encoder_hidden_states, hidden_states = block(
1183
+ hidden_states=hidden_states,
1184
+ encoder_hidden_states=encoder_hidden_states,
1185
+ temb=temb,
1186
+ image_rotary_emb=image_rotary_emb,
1187
+ attention_mask=attention_mask,
1188
+ )
1189
+
1190
+ # controlnet residual
1191
+ if controlnet_block_samples is not None:
1192
+ interval_control = len(self.transformer_blocks) / len(
1193
+ controlnet_block_samples
1194
+ )
1195
+ interval_control = int(np.ceil(interval_control))
1196
+ # For Xlabs ControlNet.
1197
+ if controlnet_blocks_repeat:
1198
+ hidden_states = (
1199
+ hidden_states
1200
+ + controlnet_block_samples[
1201
+ index_block % len(controlnet_block_samples)
1202
+ ]
1203
+ )
1204
+ else:
1205
+ hidden_states = (
1206
+ hidden_states
1207
+ + controlnet_block_samples[index_block // interval_control]
1208
+ )
1209
+
1210
+ # Flux places the text tokens in front of the image tokens in the
1211
+ # sequence.
1212
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
1213
+
1214
+ for index_block, block in enumerate(self.single_transformer_blocks):
1215
+ if (
1216
+ self.training
1217
+ and self.gradient_checkpointing
1218
+ or (
1219
+ self.gradient_checkpointing_interval is not None
1220
+ and index_block % self.gradient_checkpointing_interval == 0
1221
+ )
1222
+ ):
1223
+
1224
+ def create_custom_forward(module, return_dict=None):
1225
+ def custom_forward(*inputs):
1226
+ if return_dict is not None:
1227
+ return module(*inputs, return_dict=return_dict)
1228
+ else:
1229
+ return module(*inputs)
1230
+
1231
+ return custom_forward
1232
+
1233
+ ckpt_kwargs: Dict[str, Any] = (
1234
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1235
+ )
1236
+ hidden_states = torch.utils.checkpoint.checkpoint(
1237
+ create_custom_forward(block),
1238
+ hidden_states,
1239
+ temb,
1240
+ image_rotary_emb,
1241
+ attention_mask,
1242
+ **ckpt_kwargs,
1243
+ )
1244
+
1245
+ else:
1246
+ hidden_states = block(
1247
+ hidden_states=hidden_states,
1248
+ temb=temb,
1249
+ image_rotary_emb=image_rotary_emb,
1250
+ attention_mask=attention_mask,
1251
+ )
1252
+
1253
+ # controlnet residual
1254
+ if controlnet_single_block_samples is not None:
1255
+ interval_control = len(self.single_transformer_blocks) / len(
1256
+ controlnet_single_block_samples
1257
+ )
1258
+ interval_control = int(np.ceil(interval_control))
1259
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
1260
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
1261
+ + controlnet_single_block_samples[index_block // interval_control]
1262
+ )
1263
+
1264
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
1265
+
1266
+ hidden_states = self.norm_out(hidden_states, temb)
1267
+ output = self.proj_out(hidden_states)
1268
+
1269
+ if USE_PEFT_BACKEND:
1270
+ # remove `lora_scale` from each PEFT layer
1271
+ unscale_lora_layers(self, lora_scale)
1272
+
1273
+ if not return_dict:
1274
+ return (output,)
1275
+
1276
+ return Transformer2DModelOutput(sample=output)
1277
+
1278
+ ###################################
1279
+ # END TRANS MERGE
1280
+ ####################################
1281
+
1282
+ # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
1283
+ #
1284
+ # Licensed under the Apache License, Version 2.0 (the "License");
1285
+ # you may not use this file except in compliance with the License.
1286
+ # You may obtain a copy of the License at
1287
+ #
1288
+ # http://www.apache.org/licenses/LICENSE-2.0
1289
+ #
1290
+ # Unless required by applicable law or agreed to in writing, software
1291
+ # distributed under the License is distributed on an "AS IS" BASIS,
1292
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1293
+ # See the License for the specific language governing permissions and
1294
+ # limitations under the License.
1295
+ #
1296
+ # This was modified from the ControlNet repo.
1297
+
1298
+
1299
+ ####################################
1300
+ ##### CONTROL NET MODEL MERGE ######
1301
+ ####################################
1302
+
1303
+
1304
+ from dataclasses import dataclass
1305
+ from typing import Any, Dict, List, Optional, Tuple, Union
1306
+
1307
+ import torch
1308
+ import torch.nn as nn
1309
+
1310
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
1311
+ from diffusers.loaders import PeftAdapterMixin
1312
+ from diffusers.models.attention_processor import AttentionProcessor
1313
+ from diffusers.models.modeling_utils import ModelMixin
1314
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
1315
+ from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
1316
+ from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
1317
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
1318
+
1319
+
1320
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
1321
+
1322
+
1323
+ @dataclass
1324
+ class FluxControlNetOutput(BaseOutput):
1325
+ controlnet_block_samples: Tuple[torch.Tensor]
1326
+ controlnet_single_block_samples: Tuple[torch.Tensor]
1327
+
1328
+
1329
+ class LibreFluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
1330
+ _supports_gradient_checkpointing = True
1331
+
1332
+ @register_to_config
1333
+ def __init__(
1334
+ self,
1335
+ patch_size: int = 1,
1336
+ in_channels: int = 64,
1337
+ num_layers: int = 19,
1338
+ num_single_layers: int = 38,
1339
+ attention_head_dim: int = 128,
1340
+ num_attention_heads: int = 24,
1341
+ joint_attention_dim: int = 4096,
1342
+ pooled_projection_dim: int = 768,
1343
+ guidance_embeds: bool = False,
1344
+ axes_dims_rope: List[int] = [16, 56, 56],
1345
+ num_mode: int = None,
1346
+ conditioning_embedding_channels: int = None,
1347
+ ):
1348
+ super().__init__()
1349
+ self.out_channels = in_channels
1350
+ self.inner_dim = num_attention_heads * attention_head_dim
1351
+
1352
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
1353
+
1354
+ # edit 19
1355
+ #text_time_guidance_cls = (
1356
+ # CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
1357
+ #)
1358
+
1359
+ text_time_guidance_cls = CombinedTimestepGuidanceTextProjEmbeddings
1360
+ text_time_cls = CombinedTimestepTextProjEmbeddings
1361
+
1362
+ self.time_text_embed = text_time_cls(
1363
+ embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
1364
+ )
1365
+ self.time_text_guidance_embed = text_time_guidance_cls(
1366
+ embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
1367
+ )
1368
+
1369
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
1370
+ self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
1371
+
1372
+ self.transformer_blocks = nn.ModuleList(
1373
+ [
1374
+ FluxTransformerBlock(
1375
+ dim=self.inner_dim,
1376
+ num_attention_heads=num_attention_heads,
1377
+ attention_head_dim=attention_head_dim,
1378
+ )
1379
+ for i in range(num_layers)
1380
+ ]
1381
+ )
1382
+
1383
+ self.single_transformer_blocks = nn.ModuleList(
1384
+ [
1385
+ FluxSingleTransformerBlock(
1386
+ dim=self.inner_dim,
1387
+ num_attention_heads=num_attention_heads,
1388
+ attention_head_dim=attention_head_dim,
1389
+ )
1390
+ for i in range(num_single_layers)
1391
+ ]
1392
+ )
1393
+
1394
+ # controlnet_blocks
1395
+ self.controlnet_blocks = nn.ModuleList([])
1396
+ for _ in range(len(self.transformer_blocks)):
1397
+ self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
1398
+
1399
+ self.controlnet_single_blocks = nn.ModuleList([])
1400
+ for _ in range(len(self.single_transformer_blocks)):
1401
+ self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
1402
+
1403
+ self.union = num_mode is not None
1404
+ if self.union:
1405
+ self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
1406
+
1407
+ if conditioning_embedding_channels is not None:
1408
+ self.input_hint_block = ControlNetConditioningEmbedding(
1409
+ conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16)
1410
+ )
1411
+ self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
1412
+ else:
1413
+ self.input_hint_block = None
1414
+ self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
1415
+
1416
+ self.gradient_checkpointing = False
1417
+
1418
+ @property
1419
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
1420
+ def attn_processors(self):
1421
+ r"""
1422
+ Returns:
1423
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
1424
+ indexed by its weight name.
1425
+ """
1426
+ # set recursively
1427
+ processors = {}
1428
+
1429
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
1430
+ if hasattr(module, "get_processor"):
1431
+ processors[f"{name}.processor"] = module.get_processor()
1432
+
1433
+ for sub_name, child in module.named_children():
1434
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
1435
+
1436
+ return processors
1437
+
1438
+ for name, module in self.named_children():
1439
+ fn_recursive_add_processors(name, module, processors)
1440
+
1441
+ return processors
1442
+
1443
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
1444
+ def set_attn_processor(self, processor):
1445
+ r"""
1446
+ Sets the attention processor to use to compute attention.
1447
+
1448
+ Parameters:
1449
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
1450
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
1451
+ for **all** `Attention` layers.
1452
+
1453
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
1454
+ processor. This is strongly recommended when setting trainable attention processors.
1455
+
1456
+ """
1457
+ count = len(self.attn_processors.keys())
1458
+
1459
+ if isinstance(processor, dict) and len(processor) != count:
1460
+ raise ValueError(
1461
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
1462
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
1463
+ )
1464
+
1465
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
1466
+ if hasattr(module, "set_processor"):
1467
+ if not isinstance(processor, dict):
1468
+ module.set_processor(processor)
1469
+ else:
1470
+ module.set_processor(processor.pop(f"{name}.processor"))
1471
+
1472
+ for sub_name, child in module.named_children():
1473
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
1474
+
1475
+ for name, module in self.named_children():
1476
+ fn_recursive_attn_processor(name, module, processor)
1477
+
1478
+ def _set_gradient_checkpointing(self, module, value=False):
1479
+ if hasattr(module, "gradient_checkpointing"):
1480
+ module.gradient_checkpointing = value
1481
+
1482
+ @classmethod
1483
+ def from_transformer(
1484
+ cls,
1485
+ transformer,
1486
+ num_layers: int = 4,
1487
+ num_single_layers: int = 10,
1488
+ attention_head_dim: int = 128,
1489
+ num_attention_heads: int = 24,
1490
+ load_weights_from_transformer=True,
1491
+ ):
1492
+ config = dict(transformer.config)
1493
+ config["num_layers"] = num_layers
1494
+ config["num_single_layers"] = num_single_layers
1495
+ config["attention_head_dim"] = attention_head_dim
1496
+ config["num_attention_heads"] = num_attention_heads
1497
+
1498
+ controlnet = cls.from_config(config)
1499
+
1500
+ if load_weights_from_transformer:
1501
+ controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
1502
+ controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
1503
+ controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
1504
+ controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
1505
+ controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
1506
+ controlnet.single_transformer_blocks.load_state_dict(
1507
+ transformer.single_transformer_blocks.state_dict(), strict=False
1508
+ )
1509
+
1510
+ controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)
1511
+
1512
+ return controlnet
1513
+
1514
+ # Edit 13 Adding attention masking to forward
1515
+ def forward(
1516
+ self,
1517
+ hidden_states: torch.Tensor,
1518
+ controlnet_cond: torch.Tensor,
1519
+ controlnet_mode: torch.Tensor = None,
1520
+ conditioning_scale: float = 1.0,
1521
+ encoder_hidden_states: torch.Tensor = None,
1522
+ pooled_projections: torch.Tensor = None,
1523
+ timestep: torch.LongTensor = None,
1524
+ img_ids: torch.Tensor = None,
1525
+ txt_ids: torch.Tensor = None,
1526
+ guidance: torch.Tensor = None,
1527
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
1528
+ return_dict: bool = True,
1529
+ attention_mask: Optional[torch.Tensor] = None, # <-- 1. ADD ARGUMENT HERE
1530
+
1531
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
1532
+ """
1533
+ The [`FluxTransformer2DModel`] forward method.
1534
+
1535
+ Args:
1536
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
1537
+ Input `hidden_states`.
1538
+ controlnet_cond (`torch.Tensor`):
1539
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
1540
+ controlnet_mode (`torch.Tensor`):
1541
+ The mode tensor of shape `(batch_size, 1)`.
1542
+ conditioning_scale (`float`, defaults to `1.0`):
1543
+ The scale factor for ControlNet outputs.
1544
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
1545
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
1546
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
1547
+ from the embeddings of input conditions.
1548
+ timestep ( `torch.LongTensor`):
1549
+ Used to indicate denoising step.
1550
+ attention_mask (`torch.Tensor`, *optional*):
1551
+ Mask over the joint text/image token sequence; forwarded to every attention block (Edit 13).
1552
+ joint_attention_kwargs (`dict`, *optional*):
1553
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1554
+ `self.processor` in
1555
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1556
+ return_dict (`bool`, *optional*, defaults to `True`):
1557
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
1558
+ tuple.
1559
+
1560
+ Returns:
1561
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
1562
+ `tuple` where the first element is the sample tensor.
1563
+ """
1564
+ if joint_attention_kwargs is not None:
1565
+ joint_attention_kwargs = joint_attention_kwargs.copy()
1566
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
1567
+ else:
1568
+ lora_scale = 1.0
1569
+
1570
+ if USE_PEFT_BACKEND:
1571
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1572
+ scale_lora_layers(self, lora_scale)
1573
+ else:
1574
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
1575
+ logger.warning(
1576
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
1577
+ )
1578
+ hidden_states = self.x_embedder(hidden_states)
1579
+
1580
+ if self.input_hint_block is not None:
1581
+ controlnet_cond = self.input_hint_block(controlnet_cond)
1582
+ batch_size, channels, height_pw, width_pw = controlnet_cond.shape
1583
+ height = height_pw // self.config.patch_size
1584
+ width = width_pw // self.config.patch_size
1585
+ controlnet_cond = controlnet_cond.reshape(
1586
+ batch_size, channels, height, self.config.patch_size, width, self.config.patch_size
1587
+ )
1588
+ controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5)
1589
+ controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1)
1590
+ # add
1591
+ hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
1592
+
1593
+ timestep = timestep.to(hidden_states.dtype) * 1000
1594
+ if guidance is not None:
1595
+ guidance = guidance.to(hidden_states.dtype) * 1000
1596
+ else:
1597
+ guidance = None
1598
+
1599
+ #print ('Guidance:', guidance)
1600
+ temb = (
1601
+ self.time_text_embed(timestep, pooled_projections)
1602
+ if guidance is None
1603
+ # edit 19
1604
+ else self.time_text_guidance_embed(timestep, guidance, pooled_projections)
1605
+ )
1606
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
1607
+
1608
+ if self.union:
1609
+ # union mode
1610
+ if controlnet_mode is None:
1611
+ raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
1612
+ # union mode emb
1613
+ controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
1614
+ encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
1615
+ txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
1616
+
1617
+ if txt_ids.ndim == 3:
1618
+ logger.warning(
1619
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
1620
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
1621
+ )
1622
+ txt_ids = txt_ids[0]
1623
+ if img_ids.ndim == 3:
1624
+ logger.warning(
1625
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
1626
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
1627
+ )
1628
+ img_ids = img_ids[0]
1629
+
1630
+ ids = torch.cat((txt_ids, img_ids), dim=0)
1631
+ image_rotary_emb = self.pos_embed(ids)
1632
+
1633
+ block_samples = ()
1634
+ for index_block, block in enumerate(self.transformer_blocks):
1635
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1636
+
1637
+ def create_custom_forward(module, return_dict=None):
1638
+ def custom_forward(*inputs):
1639
+ if return_dict is not None:
1640
+ return module(*inputs, return_dict=return_dict)
1641
+ else:
1642
+ return module(*inputs)
1643
+
1644
+ return custom_forward
1645
+
1646
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1647
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
1648
+ create_custom_forward(block),
1649
+ hidden_states,
1650
+ encoder_hidden_states,
1651
+ temb,
1652
+ image_rotary_emb,
1653
+ attention_mask, # Edit 13
1654
+ **ckpt_kwargs,
1655
+ )
1656
+
1657
+ else:
1658
+ encoder_hidden_states, hidden_states = block(
1659
+ hidden_states=hidden_states,
1660
+ encoder_hidden_states=encoder_hidden_states,
1661
+ temb=temb,
1662
+ image_rotary_emb=image_rotary_emb,
1663
+ attention_mask=attention_mask, # Edit 13
1664
+
1665
+ )
1666
+ block_samples = block_samples + (hidden_states,)
1667
+
1668
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
1669
+
1670
+ single_block_samples = ()
1671
+ for index_block, block in enumerate(self.single_transformer_blocks):
1672
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1673
+
1674
+ def create_custom_forward(module, return_dict=None):
1675
+ def custom_forward(*inputs):
1676
+ if return_dict is not None:
1677
+ return module(*inputs, return_dict=return_dict)
1678
+ else:
1679
+ return module(*inputs)
1680
+
1681
+ return custom_forward
1682
+
1683
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1684
+ hidden_states = torch.utils.checkpoint.checkpoint(
1685
+ create_custom_forward(block),
1686
+ hidden_states,
1687
+ temb,
1688
+ image_rotary_emb,
1689
+ attention_mask, # <-- 2. PASS MASK TO GRADIENT CHECKPOINTING # Edit 13
1690
+ **ckpt_kwargs,
1691
+ )
1692
+
1693
+ else:
1694
+ hidden_states = block(
1695
+ hidden_states=hidden_states,
1696
+ temb=temb,
1697
+ image_rotary_emb=image_rotary_emb,
1698
+ attention_mask=attention_mask, # <-- 2. PASS MASK TO BLOCK Edit 13
1699
+
1700
+ )
1701
+ single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
1702
+
1703
+ # controlnet block
1704
+ controlnet_block_samples = ()
1705
+ for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
1706
+ block_sample = controlnet_block(block_sample)
1707
+ controlnet_block_samples = controlnet_block_samples + (block_sample,)
1708
+
1709
+ controlnet_single_block_samples = ()
1710
+ for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks):
1711
+ single_block_sample = controlnet_block(single_block_sample)
1712
+ controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)
1713
+
1714
+ # scaling
1715
+ controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
1716
+ controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]
1717
+
1718
+ controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
1719
+ controlnet_single_block_samples = (
1720
+ None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
1721
+ )
1722
+
1723
+ if USE_PEFT_BACKEND:
1724
+ # remove `lora_scale` from each PEFT layer
1725
+ unscale_lora_layers(self, lora_scale)
1726
+
1727
+ if not return_dict:
1728
+ return (controlnet_block_samples, controlnet_single_block_samples)
1729
+
1730
+ return FluxControlNetOutput(
1731
+ controlnet_block_samples=controlnet_block_samples,
1732
+ controlnet_single_block_samples=controlnet_single_block_samples,
1733
+ )
1734
+
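The ControlNet above mirrors the LibreFlux transformer blocks and emits one residual per (single) transformer block, which the transformer forward consumes as controlnet_block_samples / controlnet_single_block_samples. A minimal usage sketch follows; the already-loaded `transformer`, the tensor shapes, and the latent packing are assumptions for illustration, not part of this repo's documented API.

import torch
from controlnet.net import LibreFluxControlNetModel  # module path as uploaded here

# `transformer` is assumed to be an already-loaded LibreFluxTransformer2DModel.
controlnet = LibreFluxControlNetModel.from_transformer(
    transformer, num_layers=4, num_single_layers=10
)

# Packed Flux latents: a 64x64 latent grid becomes 1024 tokens of 64 channels.
hidden_states = torch.randn(1, 1024, 64)
controlnet_cond = torch.randn(1, 1024, 64)          # packed control-image latents
encoder_hidden_states = torch.randn(1, 512, 4096)   # T5 sequence embeddings
pooled_projections = torch.randn(1, 768)            # pooled CLIP embedding
timestep = torch.tensor([1.0])
img_ids = torch.zeros(1024, 3)
txt_ids = torch.zeros(512, 3)

block_samples, single_block_samples = controlnet(
    hidden_states=hidden_states,
    controlnet_cond=controlnet_cond,
    conditioning_scale=0.7,
    encoder_hidden_states=encoder_hidden_states,
    pooled_projections=pooled_projections,
    timestep=timestep,
    img_ids=img_ids,
    txt_ids=txt_ids,
    return_dict=False,
)
# The two tuples are then fed to the transformer as controlnet_block_samples
# and controlnet_single_block_samples.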
examples/control_ip_example.png ADDED

Git LFS Details

  • SHA256: 8f5c47ab97093c10c362b749c311a5217f8e3f695b8650f9afe3f7ab02ebb441
  • Pointer size: 132 Bytes
  • Size of remote file: 2.12 MB
examples/david.jpg ADDED

Git LFS Details

  • SHA256: 0e6a82b4100d889365ab7e7856bb1f5ebe5b3d8e2dcad6b1f45aa80885e3f49e
  • Pointer size: 131 Bytes
  • Size of remote file: 150 kB
examples/libre_flux_control_image.png ADDED

Git LFS Details

  • SHA256: 8227ba7d8884869f9dc5d3b8950d248038c3dde6181aa45da743ac8651342362
  • Pointer size: 131 Bytes
  • Size of remote file: 149 kB
examples/merc.jpeg ADDED
examples/mona.jpg ADDED

Git LFS Details

  • SHA256: ec04c3335ddf7d9122e63f609885d95b9518c456c11075ab155ad2e8b8c85a00
  • Pointer size: 131 Bytes
  • Size of remote file: 270 kB
flux_ip_adapter.py ADDED
@@ -0,0 +1,332 @@
1
+ from itertools import chain
2
+ import torch
3
+ from torch import nn
4
+ from diffusers.models.attention_processor import (
5
+ Attention,
6
+ AttentionProcessor,
7
+ )
8
+
9
+ from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
10
+ import torch.nn.functional as F
11
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
12
+ from diffusers.models.attention_processor import Attention
13
+ import inspect
14
+ from functools import partial
15
+ from diffusers.models.normalization import RMSNorm
16
+ from typing import Any, Dict, List, Optional, Union
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+
21
+ class IPFluxAttnProcessor2_0(nn.Module):
22
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
23
+
24
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, num_heads=0):
25
+ super().__init__()
26
+
27
+ self.hidden_size = hidden_size
28
+ self.cross_attention_dim = cross_attention_dim
29
+ self.scale = scale
30
+ self.num_tokens = num_tokens
31
+
32
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
33
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
34
+
35
+ self.norm_added_k = RMSNorm(128, eps=1e-5, elementwise_affine=False)
36
+
37
+ def __call__(
38
+ self,
39
+ attn,
40
+ hidden_states: torch.FloatTensor,
41
+ encoder_hidden_states: torch.FloatTensor = None,
42
+ ip_encoder_hidden_states: torch.FloatTensor = None,
43
+ attention_mask: Optional[torch.FloatTensor] = None,
44
+ image_rotary_emb: Optional[torch.Tensor] = None,
45
+ layer_scale: Optional[torch.Tensor] = None,
46
+ ) -> torch.FloatTensor:
47
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
48
+
49
+ ip_hidden_states = ip_encoder_hidden_states
50
+
51
+ # `sample` projections.
52
+ query = attn.to_q(hidden_states)
53
+ key = attn.to_k(hidden_states)
54
+ value = attn.to_v(hidden_states)
55
+
56
+ inner_dim = key.shape[-1]
57
+ head_dim = inner_dim // attn.heads
58
+
59
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
60
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
61
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
62
+
63
+ if attn.norm_q is not None:
64
+ query = attn.norm_q(query)
65
+ if attn.norm_k is not None:
66
+ key = attn.norm_k(key)
67
+
68
+ # handle IP attention FIRST
69
+
70
+
71
+ # for ip-adapter
72
+ if ip_hidden_states is not None:
73
+ ip_key = self.to_k_ip(ip_hidden_states)
74
+ ip_value = self.to_v_ip(ip_hidden_states)
75
+
76
+ # reshaping to match query shape
77
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
78
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
79
+
80
+ ip_key = self.norm_added_k(ip_key)
81
+
82
+
83
+ # Using Flux-style attention here
84
+ ip_hidden_states = F.scaled_dot_product_attention(
85
+ query,
86
+ ip_key,
87
+ ip_value,
88
+ dropout_p=0.0,
89
+ is_causal=False,
90
+ attn_mask=None,
91
+ )
92
+
93
+ # reshaping ip_hidden_states in the same way as hidden_states
94
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
95
+ batch_size, -1, attn.heads * head_dim
96
+ )
97
+
98
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
99
+ if encoder_hidden_states is not None:
100
+ # `context` projections.
101
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
102
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
103
+
104
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
105
+
106
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
107
+ batch_size, -1, attn.heads, head_dim
108
+ ).transpose(1, 2)
109
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
110
+ batch_size, -1, attn.heads, head_dim
111
+ ).transpose(1, 2)
112
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
113
+ batch_size, -1, attn.heads, head_dim
114
+ ).transpose(1, 2)
115
+
116
+ if attn.norm_added_q is not None:
117
+ encoder_hidden_states_query_proj = attn.norm_added_q(
118
+ encoder_hidden_states_query_proj
119
+ )
120
+ if attn.norm_added_k is not None:
121
+ encoder_hidden_states_key_proj = attn.norm_added_k(
122
+ encoder_hidden_states_key_proj
123
+ )
124
+
125
+ # attention
126
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
127
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
128
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
129
+
130
+ if image_rotary_emb is not None:
131
+ from diffusers.models.embeddings import apply_rotary_emb
132
+ query = apply_rotary_emb(query, image_rotary_emb)
133
+
134
+ key = apply_rotary_emb(key, image_rotary_emb)
135
+
136
+ if attention_mask is not None:
137
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
138
+ attention_mask = (attention_mask > 0).bool()
139
+ attention_mask = attention_mask.to(
140
+ device=hidden_states.device, dtype=query.dtype
141
+ )
142
+ original_hidden_states = hidden_states
143
+
144
+ hidden_states = F.scaled_dot_product_attention(
145
+ query,
146
+ key,
147
+ value,
148
+ dropout_p=0.0,
149
+ is_causal=False,
150
+ attn_mask=attention_mask,
151
+ )
152
+
153
+ hidden_states = hidden_states.transpose(1, 2).reshape(
154
+ batch_size, -1, attn.heads * head_dim
155
+ )
156
+ hidden_states = hidden_states.to(query.dtype)
157
+
158
+
159
+ layer_scale = layer_scale.view(-1, 1, 1)
160
+
161
+ if encoder_hidden_states is not None:
162
+
163
+ encoder_hidden_states, hidden_states = (
164
+ hidden_states[:, : encoder_hidden_states.shape[1]],
165
+ hidden_states[:, encoder_hidden_states.shape[1] :],
166
+ )
167
+
168
+ # Final injection of IP-Adapter hidden_states
169
+ if ip_hidden_states is not None:
170
+ hidden_states = hidden_states + (self.scale * layer_scale) * ip_hidden_states
171
+
172
+ # linear proj
173
+ hidden_states = attn.to_out[0](hidden_states)
174
+ # dropout
175
+ hidden_states = attn.to_out[1](hidden_states)
176
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
177
+
178
+ return hidden_states, encoder_hidden_states
179
+
180
+ else:
181
+
182
+ # Final injection of IP-Adapter hidden_states
183
+ if ip_hidden_states is not None:
184
+ hidden_states = hidden_states + (self.scale * layer_scale) * ip_hidden_states
185
+
186
+ if attn.to_out is not None:
187
+ hidden_states = attn.to_out[0](hidden_states)
188
+ hidden_states = attn.to_out[1](hidden_states)
189
+
190
+ return hidden_states
191
+
192
+
193
+ class ImageProjModel(nn.Module):
194
+ def __init__(self, clip_dim=768, cross_attention_dim=4096, num_tokens=16):
195
+ super().__init__()
196
+
197
+ self.num_tokens = num_tokens
198
+ self.cross_attention_dim = cross_attention_dim
199
+ self.clip_dim = clip_dim
200
+
201
+ self.proj = torch.nn.Sequential(
202
+ torch.nn.Linear(clip_dim,clip_dim*2),
203
+ torch.nn.GELU(),
204
+ torch.nn.Linear(clip_dim*2, cross_attention_dim*num_tokens),
205
+ )
206
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
207
+
208
+ def forward(self,input):
209
+
210
+ raw_proj = self.proj(input)
211
+ reshaped_proj = raw_proj.reshape(input.shape[0],self.num_tokens,self.cross_attention_dim)
212
+ reshaped_proj = self.norm( reshaped_proj )
213
+
214
+ return reshaped_proj
215
+
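+ # Shape sketch: ImageProjModel maps a pooled image embedding of shape
+ # (B, clip_dim) -> (B, clip_dim * 2) -> (B, num_tokens * cross_attention_dim)
+ # and reshapes it to (B, num_tokens, cross_attention_dim) IP tokens.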
216
+
217
+ class LibreFluxIPAdapter(nn.Module):
218
+ def __init__(self, transformer, image_proj_model, checkpoint=None):
219
+ super().__init__()
220
+ self.transformer = transformer
221
+ self.image_proj_model = image_proj_model
222
+
223
+ # Using startswith selects the Attention modules in both the double (transformer_blocks) and the single (single_transformer_blocks) blocks
224
+ self.culled_transformer_blocks = {}
225
+ for name, module in self.transformer.named_modules():
226
+ if isinstance(module, Attention):
227
+ if name.startswith('transformer_blocks') or name.startswith('single_transformer_blocks'):
228
+ #print (f"Using Transformer: {name}")
229
+ self.culled_transformer_blocks[name] = module
230
+ #else:
231
+ #print (f"Ignoring Transformer: {name}")
232
+ # Apply the adapter to the culled blocks
233
+ self.wrap_attention_blocks()
234
+
235
+ if checkpoint:
236
+ self.load_from_checkpoint(checkpoint)
237
+
238
+ def wrap_attention_blocks(self,scale=1.0, num_tokens=16):
239
+ """ Inject the IP-Adapter modules into the Transformer model """
240
+ sample_attn = self.transformer.transformer_blocks[0].attn
241
+
242
+ hidden_size = sample_attn.inner_dim
243
+ cross_attention_dim = sample_attn.cross_attention_dim
244
+ num_heads = sample_attn.heads
245
+ scale = 1.0
246
+ num_tokens = 16
247
+
248
+ processor_list = []
249
+ for name in self.culled_transformer_blocks:
250
+ module = self.culled_transformer_blocks[name]
251
+ module.processor = IPFluxAttnProcessor2_0(
252
+ hidden_size= hidden_size,
253
+ cross_attention_dim=4096,
254
+ num_heads=num_heads,
255
+ scale=1.0,
256
+ num_tokens=16,
257
+ )
258
+ processor_list.append(module.processor )
259
+ lay_count = len(processor_list)
260
+ print (f"Added Attention IP Wrapper to {lay_count} layers")
261
+
262
+ # Store adapters as a module list for saving/loading
263
+ self.adapter_modules = torch.nn.ModuleList(processor_list)
264
+
265
+ def parameters(self):
266
+ """ Easy way to return all params """
267
+ # Apply adapter
268
+ adapter_param_list = []
269
+ for name in self.culled_transformer_blocks:
270
+ module = self.culled_transformer_blocks[name]
271
+ adapter_param_list.append(module.processor.parameters())
272
+
273
+ all_params = chain(*adapter_param_list,self.image_proj_model.parameters())
274
+ return all_params
275
+
276
+ def forward(self, ref_image, *args, layer_scale= torch.Tensor([1.0]), **kwargs):
277
+ """ Run projection and run forward """
278
+ mod_dtype = next(self.image_proj_model.parameters()).dtype
279
+ mod_device = next(self.image_proj_model.parameters()).device
280
+
281
+ ip_encoder_hidden_states = None
282
+ if ref_image is not None:
283
+ ip_encoder_hidden_states = self.image_proj_model(ref_image)
284
+
285
+ # Add ip hidden states to kwargs
286
+ if 'joint_attention_kwargs' not in kwargs:
287
+ kwargs['joint_attention_kwargs'] = {}
288
+ layer_scale = layer_scale.to(dtype=mod_dtype,
289
+ device=mod_device)
290
+
291
+ kwargs['joint_attention_kwargs']['ip_layer_scale'] = layer_scale
292
+ kwargs['joint_attention_kwargs']['ip_hidden_states'] = ip_encoder_hidden_states
293
+
294
+ output = self.transformer(*args,
295
+ **kwargs)
296
+
297
+ return output
298
+
299
+ def save_pretrained(self,ckpt_path):
300
+ """ Save model weights """
301
+ state_dict = {}
302
+
303
+ state_dict["image_proj"] = self.image_proj_model.state_dict()
304
+ state_dict["ip_adapter"] = self.adapter_modules.state_dict()
305
+ torch.save(state_dict, ckpt_path)
306
+
307
+ def load_from_checkpoint(self, ckpt_path):
308
+ """ Loader ripped from tencent repo """
309
+ # Calculate original checksums
310
+ orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
311
+ orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
312
+
313
+ state_dict = torch.load(ckpt_path, map_location="cpu")
314
+
315
+ # Load state dict for image_proj_model and adapter_modules
316
+ self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
317
+ self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
318
+
319
+ # Calculate new checksums
320
+ new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
321
+ new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
322
+
323
+ # Verify if the weights have changed
324
+ assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
325
+ assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
326
+
327
+ print(f"Successfully loaded weights from checkpoint {ckpt_path}")
328
+
329
+
330
+ @property
331
+ def dtype(self):
332
+ return next(self.image_proj_model.parameters()).dtype
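The adapter above swaps an IPFluxAttnProcessor2_0 into every attention block of the transformer and adds the projected image tokens as an extra residual scaled by `scale * layer_scale`. A minimal wiring sketch, where the loaded `transformer`, the prepared latent/prompt tensors, and the SigLIP hidden size (1152) are assumptions for illustration:

import torch

# 1152 is assumed here to match the bundled SigLIP image encoder's hidden size.
proj = ImageProjModel(clip_dim=1152, cross_attention_dim=4096, num_tokens=128)
adapter = LibreFluxIPAdapter(transformer, proj)       # wraps every attention block
# adapter.load_from_checkpoint("ip_adapter.pt")       # optional: restore trained weights

ref_embed = torch.randn(1, 1152)                      # pooled SigLIP image embedding
noise_pred = adapter(
    ref_embed,                                        # projected to 128 IP tokens of width 4096
    hidden_states=latents,                            # packed latents, assumed prepared elsewhere
    encoder_hidden_states=prompt_embeds,
    pooled_projections=pooled_prompt_embeds,
    timestep=t,
    img_ids=img_ids,
    txt_ids=txt_ids,
    layer_scale=torch.tensor([0.8]),                  # per-call strength of the IP residual
    return_dict=False,
)[0]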
image_encoder/config.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "architectures": [
3
+ "SiglipVisionModel"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "dtype": "float32",
7
+ "hidden_act": "gelu_pytorch_tanh",
8
+ "hidden_size": 1152,
9
+ "image_size": 384,
10
+ "intermediate_size": 4304,
11
+ "layer_norm_eps": 1e-06,
12
+ "model_type": "siglip_vision_model",
13
+ "num_attention_heads": 16,
14
+ "num_channels": 3,
15
+ "num_hidden_layers": 27,
16
+ "patch_size": 14,
17
+ "transformers_version": "4.57.1"
18
+ }
image_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:979cc297d21bc2634d7fc0e3fdc2589b0c891acf380708ffdeb1f988e3a2d817
3
+ size 1712957296
image_encoder/preprocessor_config.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "do_convert_rgb": null,
3
+ "do_normalize": true,
4
+ "do_rescale": true,
5
+ "do_resize": true,
6
+ "image_mean": [
7
+ 0.5,
8
+ 0.5,
9
+ 0.5
10
+ ],
11
+ "image_processor_type": "SiglipImageProcessor",
12
+ "image_std": [
13
+ 0.5,
14
+ 0.5,
15
+ 0.5
16
+ ],
17
+ "processor_class": "SiglipProcessor",
18
+ "resample": 3,
19
+ "rescale_factor": 0.00392156862745098,
20
+ "size": {
21
+ "height": 384,
22
+ "width": 384
23
+ }
24
+ }
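The preprocessor config above resizes reference images to 384x384 and normalizes them to [-1, 1] (mean/std 0.5), matching SigLIP. A sketch of producing the pooled embedding that feeds ImageProjModel; the repo id is a placeholder, and only the subfolder layout is taken from this upload:

import torch
from PIL import Image
from transformers import AutoProcessor, SiglipVisionModel

processor = AutoProcessor.from_pretrained("google/siglip-so400m-patch14-384")  # same settings as above
encoder = SiglipVisionModel.from_pretrained("<this-repo-id>", subfolder="image_encoder")

image = Image.open("examples/mona.jpg").convert("RGB")
pixels = processor(images=image, return_tensors="pt").pixel_values  # (1, 3, 384, 384), scaled to [-1, 1]
with torch.no_grad():
    pooled = encoder(pixels).pooler_output                          # (1, 1152), input to ImageProjModel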
image_encoder/special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "eos_token": {
3
+ "content": "</s>",
4
+ "lstrip": true,
5
+ "normalized": false,
6
+ "rstrip": true,
7
+ "single_word": false
8
+ },
9
+ "pad_token": {
10
+ "content": "</s>",
11
+ "lstrip": true,
12
+ "normalized": false,
13
+ "rstrip": true,
14
+ "single_word": false
15
+ },
16
+ "unk_token": {
17
+ "content": "<unk>",
18
+ "lstrip": true,
19
+ "normalized": false,
20
+ "rstrip": true,
21
+ "single_word": false
22
+ }
23
+ }
image_encoder/spiece.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1e5036bed065526c3c212dfbe288752391797c4bb1a284aa18c9a0b23fcaf8ec
3
+ size 798330
image_encoder/tokenizer_config.json ADDED
@@ -0,0 +1,34 @@
1
+ {
2
+ "added_tokens_decoder": {
3
+ "1": {
4
+ "content": "</s>",
5
+ "lstrip": true,
6
+ "normalized": false,
7
+ "rstrip": true,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "2": {
12
+ "content": "<unk>",
13
+ "lstrip": true,
14
+ "normalized": false,
15
+ "rstrip": true,
16
+ "single_word": false,
17
+ "special": true
18
+ }
19
+ },
20
+ "additional_special_tokens": [],
21
+ "clean_up_tokenization_spaces": true,
22
+ "do_lower_case": true,
23
+ "eos_token": "</s>",
24
+ "extra_special_tokens": {},
25
+ "model_input_names": [
26
+ "input_ids"
27
+ ],
28
+ "model_max_length": 64,
29
+ "pad_token": "</s>",
30
+ "processor_class": "SiglipProcessor",
31
+ "sp_model_kwargs": {},
32
+ "tokenizer_class": "SiglipTokenizer",
33
+ "unk_token": "<unk>"
34
+ }
ip_adapter.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2a1fcb0c5d436335ae332704585105d547dd4908832cfd361d7a5a8101eefcd0
3
+ size 5291247917
model_index.json ADDED
@@ -0,0 +1,40 @@
1
+ {
2
+ "_class_name": "LibreFluxIPAdapterPipeline",
3
+ "_diffusers_version": "0.35.2",
4
+ "controlnet": [
5
+ "net",
6
+ "LibreFluxControlNetModel"
7
+ ],
8
+ "image_encoder": [
9
+ "transformers",
10
+ "SiglipVisionModel"
11
+ ],
12
+ "scheduler": [
13
+ "diffusers",
14
+ "FlowMatchEulerDiscreteScheduler"
15
+ ],
16
+ "text_encoder": [
17
+ "transformers",
18
+ "CLIPTextModel"
19
+ ],
20
+ "text_encoder_2": [
21
+ "transformers",
22
+ "T5EncoderModel"
23
+ ],
24
+ "tokenizer": [
25
+ "transformers",
26
+ "CLIPTokenizer"
27
+ ],
28
+ "tokenizer_2": [
29
+ "transformers",
30
+ "T5TokenizerFast"
31
+ ],
32
+ "transformer": [
33
+ "trans",
34
+ "LibreFluxTransformer2DModel"
35
+ ],
36
+ "vae": [
37
+ "diffusers",
38
+ "AutoencoderKL"
39
+ ]
40
+ }
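model_index.json wires each component to its implementation: the custom transformer (transformer/trans.py), the ControlNet (controlnet/net.py), and the LibreFluxIPAdapterPipeline in pipeline.py, alongside stock diffusers/transformers classes. A loading sketch follows; the repo id is a placeholder and the call arguments are assumptions, since the exact signature is defined in pipeline.py:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "<this-repo-id>",                 # placeholder
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,           # required so the custom classes listed above are importable
).to("cuda")

image = pipe(
    prompt="a marble bust, studio lighting",
    num_inference_steps=28,
    # control / IP-Adapter image arguments are defined by pipeline.py's __call__ (not reproduced here)
).images[0]
image.save("libre_flux_ip.png")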
pipeline.py ADDED
@@ -0,0 +1,1227 @@
1
+ # @title
2
+ # Copyright 2024 Stability AI, The HuggingFace Team, The InstantX Team, and Terminus Research Group. All rights reserved.
3
+ #
4
+ # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ #
18
+ # Originally licensed under the Apache License, Version 2.0 (the "License");
19
+ # Updated to "Affero GENERAL PUBLIC LICENSE Version 3, 19 November 2007" via extensive updates to attn_mask usage.
20
+
21
+ from typing import Any, Dict, List, Optional, Union
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+
27
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
28
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
29
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
30
+ from diffusers.models.attention import FeedForward
31
+ from diffusers.models.attention_processor import Attention
32
+
33
+ from diffusers.models.modeling_utils import ModelMixin
34
+ from diffusers.models.normalization import (
35
+ AdaLayerNormContinuous,
36
+ AdaLayerNormZero,
37
+ AdaLayerNormZeroSingle,
38
+ )
39
+ from diffusers.utils import (
40
+ USE_PEFT_BACKEND,
41
+ is_torch_version,
42
+ logging,
43
+ scale_lora_layers,
44
+ unscale_lora_layers,
45
+ )
46
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
47
+ from diffusers.models.embeddings import (
48
+ CombinedTimestepGuidanceTextProjEmbeddings,
49
+ CombinedTimestepTextProjEmbeddings,
50
+ )
51
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
52
+
53
+ from dataclasses import dataclass
54
+ from typing import List, Union
55
+ import PIL.Image
56
+ from diffusers.utils import BaseOutput
57
+
58
+ import inspect
59
+ from functools import lru_cache
60
+ from typing import Any, Callable, Dict, List, Optional, Union
61
+
62
+ import numpy as np
63
+ import torch
64
+ from transformers import (
65
+ CLIPTextModel,
66
+ CLIPTokenizer,
67
+ T5EncoderModel,
68
+ T5TokenizerFast,
69
+ CLIPVisionModelWithProjection,
70
+ CLIPTextModelWithProjection,
71
+ CLIPImageProcessor
72
+ )
73
+
74
+ from diffusers.image_processor import VaeImageProcessor
75
+ from diffusers.loaders import SD3LoraLoaderMixin
76
+ from diffusers.models.autoencoders import AutoencoderKL
77
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
78
+ from diffusers.utils import (
79
+ USE_PEFT_BACKEND,
80
+ is_torch_xla_available,
81
+ logging,
82
+ replace_example_docstring,
83
+ scale_lora_layers,
84
+ unscale_lora_layers,
85
+ )
86
+ from diffusers.utils.torch_utils import randn_tensor
87
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
88
+
89
+ from PIL import Image
90
+
91
+ from .transformer.trans import *
92
+ from .flux_ip_adapter import *
93
+ from .controlnet.net import LibreFluxControlNetModel
94
+
95
+ if is_torch_xla_available():
96
+ import torch_xla.core.xla_model as xm
97
+
98
+ XLA_AVAILABLE = True
99
+ else:
100
+ XLA_AVAILABLE = False
101
+
102
+
103
+ @dataclass
104
+ class FluxPipelineOutput(BaseOutput):
105
+ """
106
+ Output class for Stable Diffusion pipelines.
107
+ Args:
108
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
109
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
110
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
111
+ """
112
+
113
+ images: Union[List[PIL.Image.Image], np.ndarray]
114
+
115
+
116
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
117
+
118
+
119
+ EXAMPLE_DOC_STRING = """
120
+ Examples:
121
+ ```py
122
+ >>> import torch
123
+ >>> from diffusers import FluxPipeline
124
+ >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
125
+ >>> pipe.to("cuda")
126
+ >>> prompt = "A cat holding a sign that says hello world"
127
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
128
+ >>> # Refer to the pipeline documentation for more details.
129
+ >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
130
+ >>> image.save("flux.png")
131
+ ```
132
+ """
133
+
134
+
135
+ def calculate_shift(
136
+ image_seq_len,
137
+ base_seq_len: int = 256,
138
+ max_seq_len: int = 4096,
139
+ base_shift: float = 0.5,
140
+ max_shift: float = 1.16,
141
+ ):
142
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
143
+ b = base_shift - m * base_seq_len
144
+ mu = image_seq_len * m + b
145
+ return mu
146
+
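+ # Worked example (sketch) for calculate_shift above: a 1024x1024 image packs to
+ # (1024 / 16)**2 = 4096 latent tokens, giving mu = max_shift = 1.16; a 512x512
+ # image packs to 1024 tokens, giving mu ≈ 0.63.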
147
+
148
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
149
+ def retrieve_timesteps(
150
+ scheduler,
151
+ num_inference_steps: Optional[int] = None,
152
+ device: Optional[Union[str, torch.device]] = None,
153
+ timesteps: Optional[List[int]] = None,
154
+ sigmas: Optional[List[float]] = None,
155
+ **kwargs,
156
+ ):
157
+ """
158
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
159
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
160
+ Args:
161
+ scheduler (`SchedulerMixin`):
162
+ The scheduler to get timesteps from.
163
+ num_inference_steps (`int`):
164
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
165
+ must be `None`.
166
+ device (`str` or `torch.device`, *optional*):
167
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
168
+ timesteps (`List[int]`, *optional*):
169
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
170
+ `num_inference_steps` and `sigmas` must be `None`.
171
+ sigmas (`List[float]`, *optional*):
172
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
173
+ `num_inference_steps` and `timesteps` must be `None`.
174
+ Returns:
175
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
176
+ second element is the number of inference steps.
177
+ """
178
+ if timesteps is not None and sigmas is not None:
179
+ raise ValueError(
180
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
181
+ )
182
+ if timesteps is not None:
183
+ accepts_timesteps = "timesteps" in set(
184
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
185
+ )
186
+ if not accepts_timesteps:
187
+ raise ValueError(
188
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
189
+ f" timestep schedules. Please check whether you are using the correct scheduler."
190
+ )
191
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
192
+ timesteps = scheduler.timesteps
193
+ num_inference_steps = len(timesteps)
194
+ elif sigmas is not None:
195
+ accept_sigmas = "sigmas" in set(
196
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
197
+ )
198
+ if not accept_sigmas:
199
+ raise ValueError(
200
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
201
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
202
+ )
203
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
204
+ timesteps = scheduler.timesteps
205
+ num_inference_steps = len(timesteps)
206
+ else:
207
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
208
+ timesteps = scheduler.timesteps
209
+ return timesteps, num_inference_steps
210
+
211
+
212
+ class LibreFluxIpAdapterPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
213
+ r"""
214
+ The Flux pipeline for text-to-image generation.
215
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
216
+ Args:
217
+ transformer ([`LibreFluxTransformer2DModel`]):
218
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
219
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
220
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
221
+ vae ([`AutoencoderKL`]):
222
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
223
+ text_encoder ([`CLIPTextModel`]):
224
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel),
225
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
226
+ variant; only its pooled output is used.
227
+ text_encoder_2 ([`T5EncoderModel`]):
228
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel),
229
+ specifically the
230
+ [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl)
231
+ variant, which provides the main sequence embeddings.
232
+ tokenizer (`CLIPTokenizer`):
233
+ Tokenizer of class
234
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
235
+ tokenizer_2 (`T5TokenizerFast`):
236
+ Second Tokenizer of class
237
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/t5#transformers.T5TokenizerFast).
238
+
239
+ """
240
+
241
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
242
+ _optional_components = ["ip_adapter"]
243
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
244
+
245
+ def __init__(
246
+ self,
247
+ scheduler: FlowMatchEulerDiscreteScheduler,
248
+ vae: AutoencoderKL,
249
+ text_encoder: CLIPTextModel,
250
+ tokenizer: CLIPTokenizer,
251
+ text_encoder_2: T5EncoderModel,
252
+ tokenizer_2: T5TokenizerFast,
253
+ transformer: LibreFluxTransformer2DModel,
254
+ image_encoder: CLIPVisionModelWithProjection,
255
+ controlnet: Union[
256
+ LibreFluxControlNetModel, List[LibreFluxControlNetModel], Tuple[LibreFluxControlNetModel],
257
+ ],
258
+ ip_adapter: Optional[LibreFluxIPAdapter] = None
259
+ ):
260
+ super().__init__()
261
+
262
+ image_proj_model = ImageProjModel( clip_dim = image_encoder.config.hidden_size,
263
+ cross_attention_dim=4096,
264
+ num_tokens=128)
265
+
266
+ ip_adapter = LibreFluxIPAdapter(transformer,
267
+ image_proj_model)
268
+
269
+ self.ip_loaded = False
270
+
271
+ self.register_modules(
272
+ vae=vae,
273
+ text_encoder=text_encoder,
274
+ text_encoder_2=text_encoder_2,
275
+ tokenizer=tokenizer,
276
+ tokenizer_2=tokenizer_2,
277
+ transformer=transformer,
278
+ scheduler=scheduler,
279
+ image_encoder=image_encoder,
280
+ controlnet=controlnet,
281
+ ip_adapter=ip_adapter # <-- Now registered
282
+ )
283
+ self.vae_scale_factor = (
284
+ 2 ** (len(self.vae.config.block_out_channels))
285
+ if hasattr(self, "vae") and self.vae is not None
286
+ else 16
287
+ )
288
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
289
+ self.tokenizer_max_length = (
290
+ self.tokenizer.model_max_length
291
+ if hasattr(self, "tokenizer") and self.tokenizer is not None
292
+ else 77
293
+ )
294
+ self.default_sample_size = 64
295
+
296
+ #self.clip_image_processor = CLIPImageProcessor()
297
+ from transformers import AutoProcessor, SiglipVisionModel
298
+ self.clip_image_processor = AutoProcessor.from_pretrained("google/siglip-so400m-patch14-384")
299
+
300
+ def _get_t5_prompt_embeds(
301
+ self,
302
+ prompt: Union[str, List[str]] = None,
303
+ num_images_per_prompt: int = 1,
304
+ max_sequence_length: int = 512,
305
+ device: Optional[torch.device] = None,
306
+ dtype: Optional[torch.dtype] = None,
307
+ ):
308
+ device = device or self._execution_device
309
+ dtype = dtype or self.text_encoder.dtype
310
+
311
+ prompt = [prompt] if isinstance(prompt, str) else prompt
312
+ batch_size = len(prompt)
313
+
314
+ text_inputs = self.tokenizer_2(
315
+ prompt,
316
+ padding="max_length",
317
+ max_length=max_sequence_length,
318
+ truncation=True,
319
+ return_length=False,
320
+ return_overflowing_tokens=False,
321
+ return_tensors="pt",
322
+ )
323
+ prompt_attention_mask = text_inputs.attention_mask
324
+ text_input_ids = text_inputs.input_ids
325
+ untruncated_ids = self.tokenizer_2(
326
+ prompt, padding="longest", return_tensors="pt"
327
+ ).input_ids
328
+
329
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
330
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
331
+ logger.warning(
332
+ "The following part of your input was truncated because `max_sequence_length` is set to "
333
+ f" {max_sequence_length} tokens: {removed_text}"
334
+ )
335
+
336
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
337
+
338
+ dtype = self.text_encoder_2.dtype
339
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
340
+
341
+ _, seq_len, _ = prompt_embeds.shape
342
+
343
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
344
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
345
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
346
+
347
+ return prompt_embeds, prompt_attention_mask
348
+
349
+ def prepare_image(
350
+ self,
351
+ image,
352
+ width,
353
+ height,
354
+ batch_size,
355
+ num_images_per_prompt,
356
+ device,
357
+ dtype,
358
+ do_classifier_free_guidance=False,
359
+ guess_mode=False,
360
+ ):
361
+ if isinstance(image, torch.Tensor):
362
+ pass
363
+ else:
364
+ image = self.image_processor.preprocess(image, height=height, width=width)
365
+
366
+ image_batch_size = image.shape[0]
367
+
368
+ if image_batch_size == 1:
369
+ repeat_by = batch_size
370
+ else:
371
+ # image batch size is the same as prompt batch size
372
+ repeat_by = num_images_per_prompt
373
+
374
+ image = image.repeat_interleave(repeat_by, dim=0)
375
+
376
+ image = image.to(device=device, dtype=dtype)
377
+
378
+ if do_classifier_free_guidance and not guess_mode:
379
+ image = torch.cat([image] * 2)
380
+
381
+ return image
382
+
383
+ def _get_clip_prompt_embeds(
384
+ self,
385
+ prompt: Union[str, List[str]],
386
+ num_images_per_prompt: int = 1,
387
+ device: Optional[torch.device] = None,
388
+ ):
389
+ device = device or self._execution_device
390
+
391
+ prompt = [prompt] if isinstance(prompt, str) else prompt
392
+ batch_size = len(prompt)
393
+
394
+ text_inputs = self.tokenizer(
395
+ prompt,
396
+ padding="max_length",
397
+ max_length=self.tokenizer_max_length,
398
+ truncation=True,
399
+ return_overflowing_tokens=False,
400
+ return_length=False,
401
+ return_tensors="pt",
402
+ )
403
+
404
+ text_input_ids = text_inputs.input_ids
405
+ untruncated_ids = self.tokenizer(
406
+ prompt, padding="longest", return_tensors="pt"
407
+ ).input_ids
408
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
409
+ text_input_ids, untruncated_ids
410
+ ):
411
+ removed_text = self.tokenizer.batch_decode(
412
+ untruncated_ids[:, self.tokenizer_max_length - 1 : -1]
413
+ )
414
+ logger.warning(
415
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
416
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
417
+ )
418
+ prompt_embeds = self.text_encoder(
419
+ text_input_ids.to(device), output_hidden_states=False
420
+ )
421
+
422
+ # Use pooled output of CLIPTextModel
423
+ prompt_embeds = prompt_embeds.pooler_output
424
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
425
+
426
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
427
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
428
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
429
+
430
+ return prompt_embeds
431
+
432
+ @lru_cache(maxsize=128)
433
+ def encode_prompt(
434
+ self,
435
+ prompt: Union[str, List[str]],
436
+ prompt_2: Union[str, List[str]],
437
+ device: Optional[torch.device] = None,
438
+ num_images_per_prompt: int = 1,
439
+ prompt_embeds: Optional[torch.FloatTensor] = None,
440
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
441
+ max_sequence_length: int = 512,
442
+ lora_scale: Optional[float] = None,
443
+ ):
444
+ r"""
445
+ Args:
446
+ prompt (`str` or `List[str]`, *optional*):
447
+ prompt to be encoded
448
+ prompt_2 (`str` or `List[str]`, *optional*):
449
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
450
+ used in all text-encoders
451
+ device: (`torch.device`):
452
+ torch device
453
+ num_images_per_prompt (`int`):
454
+ number of images that should be generated per prompt
455
+ prompt_embeds (`torch.FloatTensor`, *optional*):
456
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
457
+ provided, text embeddings will be generated from `prompt` input argument.
458
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
459
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
460
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
461
+ clip_skip (`int`, *optional*):
462
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
463
+ the output of the pre-final layer will be used for computing the prompt embeddings.
464
+ lora_scale (`float`, *optional*):
465
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
466
+ """
467
+ device = device or self._execution_device
468
+
469
+ # set lora scale so that monkey patched LoRA
470
+ # function of text encoder can correctly access it
471
+ if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
472
+ self._lora_scale = lora_scale
473
+
474
+ # dynamically adjust the LoRA scale
475
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
476
+ scale_lora_layers(self.text_encoder, lora_scale)
477
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
478
+ scale_lora_layers(self.text_encoder_2, lora_scale)
479
+
480
+ prompt = [prompt] if isinstance(prompt, str) else prompt
481
+ if prompt is not None:
482
+ batch_size = len(prompt)
483
+ else:
484
+ batch_size = prompt_embeds.shape[0]
485
+
486
+ prompt_attention_mask = None
487
+ if prompt_embeds is None:
488
+ prompt_2 = prompt_2 or prompt
489
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
490
+
491
+ # We only use the pooled prompt output from the CLIPTextModel
492
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
493
+ prompt=prompt,
494
+ device=device,
495
+ num_images_per_prompt=num_images_per_prompt,
496
+ )
497
+ prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
498
+ prompt=prompt_2,
499
+ num_images_per_prompt=num_images_per_prompt,
500
+ max_sequence_length=max_sequence_length,
501
+ device=device,
502
+ )
503
+
504
+ if self.text_encoder is not None:
505
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
506
+ # Retrieve the original scale by scaling back the LoRA layers
507
+ unscale_lora_layers(self.text_encoder, lora_scale)
508
+
509
+ if self.text_encoder_2 is not None:
510
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
511
+ # Retrieve the original scale by scaling back the LoRA layers
512
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
513
+
514
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
515
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
516
+ text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
517
+
518
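+        # With the encoders bundled in this repo: prompt_embeds is the T5 sequence output
+        # (batch, seq_len, 4096), pooled_prompt_embeds is the CLIP pooled vector (batch, 768),
+        # and text_ids is an all-zero (batch, seq_len, 3) tensor used for the text RoPE positions.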
+ return prompt_embeds, pooled_prompt_embeds, text_ids, prompt_attention_mask
519
+
520
+ def check_inputs(
521
+ self,
522
+ prompt,
523
+ prompt_2,
524
+ height,
525
+ width,
526
+ prompt_embeds=None,
527
+ pooled_prompt_embeds=None,
528
+ callback_on_step_end_tensor_inputs=None,
529
+ max_sequence_length=None,
530
+ ):
531
+ if height % 8 != 0 or width % 8 != 0:
532
+ raise ValueError(
533
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
534
+ )
535
+
536
+ if callback_on_step_end_tensor_inputs is not None and not all(
537
+ k in self._callback_tensor_inputs
538
+ for k in callback_on_step_end_tensor_inputs
539
+ ):
540
+ raise ValueError(
541
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
542
+ )
543
+
544
+ if prompt is not None and prompt_embeds is not None:
545
+ raise ValueError(
546
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
547
+ " only forward one of the two."
548
+ )
549
+ elif prompt_2 is not None and prompt_embeds is not None:
550
+ raise ValueError(
551
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
552
+ " only forward one of the two."
553
+ )
554
+ elif prompt is None and prompt_embeds is None:
555
+ raise ValueError(
556
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
557
+ )
558
+ elif prompt is not None and (
559
+ not isinstance(prompt, str) and not isinstance(prompt, list)
560
+ ):
561
+ raise ValueError(
562
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
563
+ )
564
+ elif prompt_2 is not None and (
565
+ not isinstance(prompt_2, str) and not isinstance(prompt_2, list)
566
+ ):
567
+ raise ValueError(
568
+ f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}"
569
+ )
570
+
571
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
572
+ raise ValueError(
573
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
574
+ )
575
+
576
+ if max_sequence_length is not None and max_sequence_length > 512:
577
+ raise ValueError(
578
+ f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}"
579
+ )
580
+
581
+ @staticmethod
582
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
583
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
584
+ latent_image_ids[..., 1] = (
585
+ latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
586
+ )
587
+ latent_image_ids[..., 2] = (
588
+ latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
589
+ )
590
+
591
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = (
592
+ latent_image_ids.shape
593
+ )
594
+
595
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
596
+ latent_image_ids = latent_image_ids.reshape(
597
+ batch_size,
598
+ latent_image_id_height * latent_image_id_width,
599
+ latent_image_id_channels,
600
+ )
601
+
602
+ return latent_image_ids.to(dtype=dtype, device=device)
603
+
604
+ @staticmethod
605
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
606
+ latents = latents.view(
607
+ batch_size, num_channels_latents, height // 2, 2, width // 2, 2
608
+ )
609
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
610
+ latents = latents.reshape(
611
+ batch_size, (height // 2) * (width // 2), num_channels_latents * 4
612
+ )
613
+
614
+ return latents
615
+
616
+ @staticmethod
617
+ def _unpack_latents(latents, height, width, vae_scale_factor):
618
+ batch_size, num_patches, channels = latents.shape
619
+
620
+ height = height // vae_scale_factor
621
+ width = width // vae_scale_factor
622
+
623
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
624
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
625
+
626
+ latents = latents.reshape(
627
+ batch_size, channels // (2 * 2), height * 2, width * 2
628
+ )
629
+
630
+ return latents
631
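+    # Shape round-trip of the 2x2 patch packing, with illustrative sizes (assuming the Flux
+    # convention where vae_scale_factor is 16, i.e. twice the VAE's 8x spatial downsample):
+    #   latents                                  -> (1, 16, 128, 128)  for a 1024x1024 image
+    #   _pack_latents(latents, 1, 16, 128, 128)  -> (1, 4096, 64)      64x64 tokens of dim 16*4
+    #   _unpack_latents(packed, 1024, 1024, 16)  -> (1, 16, 128, 128)  exact inverse of packing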
+
632
+ def prepare_latents(
633
+ self,
634
+ batch_size,
635
+ num_channels_latents,
636
+ height,
637
+ width,
638
+ dtype,
639
+ device,
640
+ generator,
641
+ latents=None,
642
+ ):
643
+ height = 2 * (int(height) // self.vae_scale_factor)
644
+ width = 2 * (int(width) // self.vae_scale_factor)
645
+
646
+ shape = (batch_size, num_channels_latents, height, width)
647
+
648
+ if latents is not None:
649
+ latent_image_ids = self._prepare_latent_image_ids(
650
+ batch_size, height, width, device, dtype
651
+ )
652
+ return latents, latent_image_ids
653
+
654
+ if isinstance(generator, list) and len(generator) != batch_size:
655
+ raise ValueError(
656
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
657
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
658
+ )
659
+
660
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
661
+ latents = self._pack_latents(
662
+ latents, batch_size, num_channels_latents, height, width
663
+ )
664
+
665
+ latent_image_ids = self._prepare_latent_image_ids(
666
+ batch_size, height, width, device, dtype
667
+ )
668
+
669
+ return latents, latent_image_ids
670
+
671
+ @property
672
+ def guidance_scale(self):
673
+ return self._guidance_scale
674
+
675
+ @property
676
+ def joint_attention_kwargs(self):
677
+ return self._joint_attention_kwargs
678
+
679
+ @property
680
+ def num_timesteps(self):
681
+ return self._num_timesteps
682
+
683
+ @property
684
+ def interrupt(self):
685
+ return self._interrupt
686
+
687
+ @torch.no_grad()
688
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
689
+ def __call__(
690
+ self,
691
+ prompt: Union[str, List[str]] = None,
692
+ prompt_mask: Optional[Union[torch.FloatTensor, List[torch.FloatTensor]]] = None,
693
+ negative_mask: Optional[
694
+ Union[torch.FloatTensor, List[torch.FloatTensor]]
695
+ ] = None,
696
+ prompt_2: Optional[Union[str, List[str]]] = None,
697
+ height: Optional[int] = None,
698
+ width: Optional[int] = None,
699
+ num_inference_steps: int = 28,
700
+ timesteps: List[int] = None,
701
+ guidance_scale: float = 3.5,
702
+ control_image: PipelineImageInput = None,
703
+ control_mode: Optional[Union[int, List[int]]] = None,
704
+ control_image_undo_centering: bool = False,
705
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
706
+ num_images_per_prompt: Optional[int] = 1,
707
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
708
+ latents: Optional[torch.FloatTensor] = None,
709
+ prompt_embeds: Optional[torch.FloatTensor] = None,
710
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
711
+ output_type: Optional[str] = "pil",
712
+ return_dict: bool = True,
713
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
714
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
715
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
716
+ max_sequence_length: int = 512,
717
+ guidance_scale_real: float = 1.0,
718
+ negative_prompt: Union[str, List[str]] = "",
719
+ negative_prompt_2: Union[str, List[str]] = "",
720
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
721
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
722
+ no_cfg_until_timestep: int = 0,
723
+ do_batch_cfg: bool=True,
724
+ ip_adapter_image: PipelineImageInput=None,
725
+ ip_adapter_scale: float=1.0,
726
+        device=torch.device('cuda'),  # TODO: support non-CUDA devices; passing None falls back to self._execution_device
727
+ ):
728
+ r"""
729
+ Function invoked when calling the pipeline for generation.
730
+ Args:
731
+ prompt (`str` or `List[str]`, *optional*):
732
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
733
+ instead.
734
+            prompt_mask (`torch.FloatTensor` or `List[torch.FloatTensor]`, *optional*):
735
+                Attention mask(s) for the prompt embeddings. If not provided, the attention mask computed by `encode_prompt` is used
736
+ instead.
737
+ prompt_2 (`str` or `List[str]`, *optional*):
738
+                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
739
+ will be used instead
740
+            height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
741
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
742
+            width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
743
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
744
+            num_inference_steps (`int`, *optional*, defaults to 28):
745
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
746
+ expense of slower inference.
747
+ timesteps (`List[int]`, *optional*):
748
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
749
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
750
+ passed will be used. Must be in descending order.
751
+            guidance_scale (`float`, *optional*, defaults to 3.5):
752
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
753
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
754
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
755
+                1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
756
+ usually at the expense of lower image quality.
757
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
758
+ The number of images to generate per prompt.
759
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
760
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
761
+ to make generation deterministic.
762
+ latents (`torch.FloatTensor`, *optional*):
763
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
764
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
765
+                tensor will be generated by sampling using the supplied random `generator`.
766
+ prompt_embeds (`torch.FloatTensor`, *optional*):
767
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
768
+ provided, text embeddings will be generated from `prompt` input argument.
769
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
770
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
771
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
772
+ output_type (`str`, *optional*, defaults to `"pil"`):
773
+                The output format of the generated image. Choose between
774
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
775
+ return_dict (`bool`, *optional*, defaults to `True`):
776
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
777
+ joint_attention_kwargs (`dict`, *optional*):
778
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
779
+ `self.processor` in
780
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
781
+ callback_on_step_end (`Callable`, *optional*):
782
+                A function that is called at the end of each denoising step during inference. The function is called
783
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
784
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
785
+ `callback_on_step_end_tensor_inputs`.
786
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
787
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
788
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
789
+ `._callback_tensor_inputs` attribute of your pipeline class.
790
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
791
+ Examples:
792
+ Returns:
793
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
794
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
795
+ images.
796
+ """
797
+
798
+ height = height or self.default_sample_size * self.vae_scale_factor
799
+ width = width or self.default_sample_size * self.vae_scale_factor
800
+
801
+ # 1. Check inputs. Raise error if not correct
802
+ self.check_inputs(
803
+ prompt,
804
+ prompt_2,
805
+ height,
806
+ width,
807
+ prompt_embeds=prompt_embeds,
808
+ pooled_prompt_embeds=pooled_prompt_embeds,
809
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
810
+ max_sequence_length=max_sequence_length,
811
+ )
812
+
813
+        # guidance_scale_real drives the actual classifier-free guidance below; it exists only
814
+        # for backwards compatibility with the original pipeline signature, so it is simply
815
+        # kept in sync with guidance_scale here.
816
+ guidance_scale_real = guidance_scale
817
+
818
+ self._guidance_scale = guidance_scale
819
+ self._guidance_scale_real = guidance_scale_real
820
+ self._joint_attention_kwargs = joint_attention_kwargs
821
+ self._interrupt = False
822
+
823
+ # 2. Define call parameters
824
+ if prompt is not None and isinstance(prompt, str):
825
+ batch_size = 1
826
+ elif prompt is not None and isinstance(prompt, list):
827
+ batch_size = len(prompt)
828
+ else:
829
+ batch_size = prompt_embeds.shape[0]
830
+
831
+ device = device or self._execution_device
832
+ dtype = self.transformer.dtype
833
+
834
+ lora_scale = (
835
+ self.joint_attention_kwargs.get("scale", None)
836
+ if self.joint_attention_kwargs is not None
837
+ else None
838
+ )
839
+ (
840
+ prompt_embeds,
841
+ pooled_prompt_embeds,
842
+ text_ids,
843
+ _prompt_mask,
844
+ ) = self.encode_prompt(
845
+ prompt=prompt,
846
+ prompt_2=prompt_2,
847
+ prompt_embeds=prompt_embeds,
848
+ pooled_prompt_embeds=pooled_prompt_embeds,
849
+ device=device,
850
+ num_images_per_prompt=num_images_per_prompt,
851
+ max_sequence_length=max_sequence_length,
852
+ lora_scale=lora_scale,
853
+ )
854
+ if _prompt_mask is not None:
855
+ prompt_mask = _prompt_mask
856
+ assert prompt_mask is not None
857
+
858
+ if negative_prompt_2 == "" and negative_prompt != "":
859
+ negative_prompt_2 = negative_prompt
860
+
861
+ negative_text_ids = text_ids
862
+ if self._guidance_scale_real > 1.0 and (
863
+ negative_prompt_embeds is None or negative_pooled_prompt_embeds is None
864
+ ):
865
+ (
866
+ negative_prompt_embeds,
867
+ negative_pooled_prompt_embeds,
868
+ negative_text_ids,
869
+ _neg_prompt_mask,
870
+ ) = self.encode_prompt(
871
+ prompt=negative_prompt,
872
+ prompt_2=negative_prompt_2,
873
+ prompt_embeds=None,
874
+ pooled_prompt_embeds=None,
875
+ device=device,
876
+ num_images_per_prompt=num_images_per_prompt,
877
+ max_sequence_length=max_sequence_length,
878
+ lora_scale=lora_scale,
879
+ )
880
+
881
+ if _neg_prompt_mask is not None:
882
+ negative_mask = _neg_prompt_mask
883
+
884
+ assert negative_mask is not None
885
+
886
+
887
+ ##################################
888
+        # CONTROL NET - VARIABLE PREP
889
+ ##################################
890
+        if control_image is not None:
891
+ # 3. Prepare control image
892
+ num_channels_latents = self.transformer.config.in_channels // 4
893
+
894
+
895
+ inner_module = self.controlnet
896
+
897
+ control_image = self.prepare_image(
898
+ image=control_image,
899
+ width=width,
900
+ height=height,
901
+ batch_size=batch_size * num_images_per_prompt,
902
+ num_images_per_prompt=num_images_per_prompt,
903
+ device=device,
904
+ dtype=dtype,
905
+ )
906
+
907
+ if control_image_undo_centering:
908
+ if not self.image_processor.do_normalize:
909
+ raise ValueError(
910
+ "`control_image_undo_centering` only makes sense if `do_normalize==True` in the image processor"
911
+ )
912
+                control_image = control_image * 0.5 + 0.5
913
+
914
+ height, width = control_image.shape[-2:]
915
+
916
+ # vae encode
917
+ control_image = self.vae.encode(control_image).latent_dist.sample()
918
+ control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
919
+ # pack
920
+ height_control_image, width_control_image = control_image.shape[2:]
921
+ control_image = self._pack_latents(
922
+ control_image,
923
+ batch_size * num_images_per_prompt,
924
+ num_channels_latents,
925
+ height_control_image,
926
+ width_control_image,
927
+ )
928
+
929
+ # set control mode
930
+ if control_mode is not None:
931
+ control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
932
+ control_mode = control_mode.reshape([-1, 1])
933
+
934
+
935
+ # set control mode
936
+ control_mode_ = []
937
+ if isinstance(control_mode, list):
938
+ for cmode in control_mode:
939
+ if cmode is None:
940
+ control_mode_.append(-1)
941
+ else:
942
+ control_mode_.append(cmode)
943
+ control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
944
+ control_mode = control_mode.reshape([-1, 1])
945
+ else:
946
+ control_image = None
947
+ control_mode = None
948
+ ##################################
949
+        # END CONTROL NET - VARIABLE PREP
950
+ ##################################
951
+
952
+ # 4. Prepare latent variables
953
+ num_channels_latents = self.transformer.config.in_channels // 4
954
+ latents, latent_image_ids = self.prepare_latents(
955
+ batch_size * num_images_per_prompt,
956
+ num_channels_latents,
957
+ height,
958
+ width,
959
+ prompt_embeds.dtype,
960
+ device,
961
+ generator,
962
+ latents,
963
+ )
964
+
965
+ # 5. Prepare timesteps
966
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
967
+ image_seq_len = latents.shape[1]
968
+ mu = calculate_shift(
969
+ image_seq_len,
970
+ self.scheduler.config.base_image_seq_len,
971
+ self.scheduler.config.max_image_seq_len,
972
+ self.scheduler.config.base_shift,
973
+ self.scheduler.config.max_shift,
974
+ )
975
+ timesteps, num_inference_steps = retrieve_timesteps(
976
+ self.scheduler,
977
+ num_inference_steps,
978
+ device,
979
+ timesteps,
980
+ sigmas,
981
+ mu=mu,
982
+ )
983
+ num_warmup_steps = max(
984
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
985
+ )
986
+ self._num_timesteps = len(timesteps)
987
+
988
+ latents = latents
989
+ latent_image_ids = latent_image_ids
990
+ timesteps = timesteps
991
+ text_ids = text_ids.to(device=device)
992
+
993
+ # Denoising loop
994
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
995
+ for i, t in enumerate(timesteps):
996
+ if self.interrupt:
997
+ continue
998
+
999
+ # Prepare the latent model input
1000
+ prompt_embeds_input = prompt_embeds
1001
+ pooled_prompt_embeds_input = pooled_prompt_embeds
1002
+ text_ids_input = text_ids
1003
+ latent_image_ids_input = latent_image_ids
1004
+ prompt_mask_input = prompt_mask
1005
+ latent_model_input = latents
1006
+
1007
+ if guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
1008
+ progress_bar.set_postfix(
1009
+ {
1010
+ 'ts': t.detach().item() / 1000.0,
1011
+ 'cfg': self._guidance_scale_real,
1012
+ },
1013
+ )
1014
+ else:
1015
+ progress_bar.set_postfix(
1016
+ {
1017
+ 'ts': t.detach().item() / 1000.0,
1018
+ 'cfg': 'N/A',
1019
+ },
1020
+ )
1021
+
1022
+ # Forward pass through the transformer
1023
+ with torch.no_grad():
1024
+                    if ip_adapter_image is not None and self.ip_loaded:
1025
+
1026
+ clip_image = self.clip_image_processor(images=ip_adapter_image,
1027
+ return_tensors="pt").pixel_values
1028
+ clip_image = clip_image.to(device=self.image_encoder.device,
1029
+ dtype=self.image_encoder.dtype)
1030
+ image_embeds = self.image_encoder(clip_image).pooler_output
1031
+ image_embeds_input = image_embeds
1032
+ else:
1033
+ image_embeds = None
1034
+ image_embeds_input = None
1035
+
1036
+ layer_scale = torch.Tensor([ip_adapter_scale])
1037
+ layer_scale_input = layer_scale
1038
+ neg_layer_scale = torch.Tensor([0.0])
1039
+ current_control_image = control_image
1040
+
1041
+
1042
+
1043
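+                    # Batched CFG: stack the negative and positive conditioning along the batch
+                    # dimension (negative first) so a single transformer pass yields both the
+                    # unconditional and conditional predictions, which are split again below.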
+ if do_batch_cfg and guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
1044
+ # Concatenate prompt embeddings
1045
+ prompt_embeds_input = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1046
+ pooled_prompt_embeds_input = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
1047
+
1048
+ image_embeds_input = None
1049
+                        if image_embeds is not None:
1050
+ image_embeds_input = torch.cat([image_embeds]*2, dim=0)
1051
+
1052
+                        layer_scale_input = torch.cat([neg_layer_scale, layer_scale], dim=0)
1053
+ # Concatenate text IDs if they are used
1054
+ # if text_ids is not None and negative_text_ids is not None:
1055
+ # text_ids_input = torch.cat([negative_text_ids, text_ids], dim=0)
1056
+
1057
+ # Concatenate latent image IDs if they are used
1058
+ # if latent_image_ids is not None:
1059
+ # latent_image_ids_input = torch.cat([latent_image_ids, latent_image_ids], dim=0)
1060
+
1061
+ # Concatenate prompt masks if they are used
1062
+ if prompt_mask is not None and negative_mask is not None:
1063
+ prompt_mask_input = torch.cat([negative_mask, prompt_mask], dim=0)
1064
+ # Duplicate latents for unconditional and conditional inputs
1065
+ latent_model_input = torch.cat([latents] * 2)
1066
+                        if control_image is not None:
1067
+ current_control_image = torch.cat([control_image] * 2)
1068
+ else:
1069
+ current_control_image = None
1070
+
1071
+ # Expand timestep to match batch size
1072
+ timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
1073
+
1074
+ guidance = None
1075
+
1076
+ div_timestep = (timestep / 1000.0)
1077
+ text_ids = [ t for t in text_ids ]
1078
+
1079
+ ######################################
1080
+ # ADD CONTROLNET - FORWARD
1081
+ ######################################
1082
+                    if control_image is not None:
1083
+ controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
1084
+ hidden_states=latent_model_input,
1085
+ controlnet_cond=current_control_image,
1086
+ controlnet_mode=control_mode,
1087
+ conditioning_scale=controlnet_conditioning_scale,
1088
+ timestep=div_timestep,
1089
+ guidance=None,
1090
+ pooled_projections=pooled_prompt_embeds_input,
1091
+ encoder_hidden_states=prompt_embeds_input,
1092
+ attention_mask=prompt_mask_input,
1093
+ txt_ids=text_ids_input[0],
1094
+ img_ids=latent_image_ids_input[0],
1095
+ joint_attention_kwargs=self.joint_attention_kwargs,
1096
+ return_dict=False
1097
+ )
1098
+ else:
1099
+ controlnet_block_samples = None
1100
+ controlnet_single_block_samples = None
1101
+ ######################################
1102
+ # END - ADD CONTROLNET - FORWARD
1103
+ ######################################
1104
+
1105
+
1106
+ noise_pred = self.ip_adapter(
1107
+ image_embeds_input,
1108
+ latent_model_input.to(device=self.transformer.device),
1109
+ layer_scale=layer_scale_input,
1110
+ timestep=div_timestep.to(device=self.transformer.device),
1111
+ guidance=None,
1112
+ pooled_projections=pooled_prompt_embeds_input.to(device=self.transformer.device),
1113
+ encoder_hidden_states=prompt_embeds_input.to(device=self.transformer.device),
1114
+ attention_mask=prompt_mask_input.to(device=self.transformer.device),
1115
+ controlnet_block_samples=controlnet_block_samples, ### A CONTROL NET INPUT
1116
+ controlnet_single_block_samples=controlnet_single_block_samples, ### A CONTROL NET INPUT
1117
+ txt_ids=text_ids_input[0],
1118
+ img_ids=latent_image_ids_input[0].to(device=self.transformer.device),
1119
+ return_dict=False,
1120
+ )[0]
1121
+
1122
+ # Apply real CFG
1123
+ if guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
1124
+ if do_batch_cfg:
1125
+ # Batched CFG: Split the noise prediction into unconditional and conditional parts
1126
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
1127
+ noise_pred = noise_pred_uncond + guidance_scale_real * (noise_pred_cond - noise_pred_uncond)
1128
+ else:
1129
+ # Sequential CFG: Compute unconditional noise prediction separately
1130
+                            if control_image is not None:
1131
+ controlnet_block_samples_uncond, controlnet_single_block_samples_uncond = self.controlnet(
1132
+ hidden_states=latent_model_input,
1133
+ controlnet_cond=control_image,
1134
+ controlnet_mode=control_mode,
1135
+ conditioning_scale=controlnet_conditioning_scale,
1136
+ timestep=div_timestep,
1137
+ guidance=None,
1138
+ pooled_projections=negative_pooled_prompt_embeds.to(device=self.transformer.device),
1139
+ encoder_hidden_states=negative_prompt_embeds.to(device=self.transformer.device),
1140
+ attention_mask=negative_mask,
1141
+ txt_ids=negative_text_ids.to(device=self.transformer.device) if negative_text_ids is not None else None,
1142
+ img_ids=latent_image_ids[0].to(device=self.transformer.device),
1143
+ joint_attention_kwargs=self.joint_attention_kwargs,
1144
+ return_dict=False
1145
+ )
1146
+ else:
1147
+ controlnet_block_samples_uncond = None
1148
+ controlnet_single_block_samples_uncond = None
1149
+
1150
+ noise_pred_uncond = self.ip_adapter(
1151
+ image_embeds,
1152
+ latents.to(device=self.transformer.device),
1153
+ layer_scale=neg_layer_scale,
1154
+ timestep=div_timestep,
1155
+ guidance=None,
1156
+ pooled_projections=negative_pooled_prompt_embeds.to(device=self.transformer.device),
1157
+ encoder_hidden_states=negative_prompt_embeds.to(device=self.transformer.device),
1158
+ attention_mask=negative_mask,
1159
+ controlnet_block_samples=controlnet_block_samples_uncond, ### A CONTROL NET INPUT
1160
+ controlnet_single_block_samples=controlnet_single_block_samples_uncond, ### A CONTROL NET INPUT
1161
+ txt_ids=negative_text_ids.to(device=self.transformer.device) if negative_text_ids is not None else None,
1162
+ img_ids=latent_image_ids[0].to(device=self.transformer.device),
1163
+ return_dict=False,
1164
+ )[0]
1165
+
1166
+ # Combine conditional and unconditional predictions
1167
+ noise_pred = noise_pred_uncond + guidance_scale_real * (noise_pred - noise_pred_uncond)
1168
+
1169
+ # Compute the previous noisy sample x_t -> x_t-1
1170
+ latents_dtype = latents.dtype
1171
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1172
+
1173
+ # Ensure latents have the correct dtype
1174
+ if latents.dtype != latents_dtype:
1175
+ if torch.backends.mps.is_available():
1176
+ latents = latents.to(latents_dtype)
1177
+
1178
+ # Callback at the end of the step, if provided
1179
+ if callback_on_step_end is not None:
1180
+ callback_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs}
1181
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1182
+ latents = callback_outputs.get("latents", latents)
1183
+ prompt_embeds = callback_outputs.get("prompt_embeds", prompt_embeds)
1184
+
1185
+ # Update the progress bar
1186
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1187
+ progress_bar.update()
1188
+
1189
+ # Mark step for XLA devices
1190
+ if XLA_AVAILABLE:
1191
+ xm.mark_step()
1192
+
1193
+ if output_type == "latent":
1194
+ image = latents
1195
+
1196
+ else:
1197
+ latents = self._unpack_latents(
1198
+ latents, height, width, self.vae_scale_factor
1199
+ )
1200
+ latents = (
1201
+ latents / self.vae.config.scaling_factor
1202
+ ) + self.vae.config.shift_factor
1203
+
1204
+ latents = latents.to(dtype=self.vae.dtype)
1205
+
1206
+ image = self.vae.decode(
1207
+ latents,
1208
+ return_dict=False,
1209
+ )[0]
1210
+ image = self.image_processor.postprocess(image, output_type=output_type)
1211
+
1212
+ # Offload all models
1213
+ self.maybe_free_model_hooks()
1214
+
1215
+ if not return_dict:
1216
+ return (image,)
1217
+
1218
+ return FluxPipelineOutput(images=image)
1219
+
1220
+ def load_ip_adapter(self, checkpoint_path):
1221
+ """ Init model and load weights, or just load weights"""
1222
+
1223
+ self.ip_adapter.load_from_checkpoint(checkpoint_path)
1224
+        self.ip_adapter.to(self.transformer.device, dtype=self.dtype)
1225
+ self.ip_loaded = True
1226
+
1227
+ return self
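
For reference, a minimal end-to-end sketch of driving this pipeline follows. The repo id, the reference-image path, and the need for `trust_remote_code` are assumptions for illustration, not facts from this commit; `ip_adapter.pt` is the checkpoint uploaded at the repo root.

    import torch
    from diffusers import DiffusionPipeline
    from diffusers.utils import load_image

    # "<repo-id>" is a placeholder for this repository; the custom pipeline class in
    # pipeline.py is assumed to load via trust_remote_code=True.
    pipe = DiffusionPipeline.from_pretrained(
        "<repo-id>", torch_dtype=torch.bfloat16, trust_remote_code=True
    ).to("cuda")
    pipe.load_ip_adapter("ip_adapter.pt")  # local path to the downloaded IP-Adapter checkpoint

    reference = load_image("reference.png")  # placeholder path to a style/identity reference image
    result = pipe(
        prompt="a portrait of a woman, oil painting",
        ip_adapter_image=reference,
        ip_adapter_scale=1.0,
        guidance_scale=3.5,
        num_inference_steps=28,
    ).images[0]
    result.save("out.png")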
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "_class_name": "FlowMatchEulerDiscreteScheduler",
3
+ "_diffusers_version": "0.35.2",
4
+ "base_image_seq_len": 256,
5
+ "base_shift": 0.5,
6
+ "max_image_seq_len": 4096,
7
+ "max_shift": 1.15,
8
+ "num_train_timesteps": 1000,
9
+ "shift": 1.0,
10
+ "use_dynamic_shifting": false
11
+ }
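
The `base_image_seq_len`/`max_image_seq_len` and `base_shift`/`max_shift` values above drive the dynamic timestep shift computed in `__call__`. A minimal sketch of that linear interpolation, using this file's values (exact parity with the `calculate_shift` helper called in pipeline.py is assumed, not verified here):

    def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15):
        # Linearly interpolate mu between base_shift and max_shift according to
        # the number of packed latent tokens.
        m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
        b = base_shift - m * base_seq_len
        return image_seq_len * m + b

    print(calculate_shift(64 * 64))  # 1024x1024 image -> 4096 tokens -> mu = 1.15
    print(calculate_shift(32 * 32))  # 512x512 image   -> 1024 tokens -> mu = 0.63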
text_encoder/config.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "_name_or_path": "openai/clip-vit-large-patch14",
3
+ "architectures": [
4
+ "CLIPTextModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 0,
8
+ "dropout": 0.0,
9
+ "eos_token_id": 2,
10
+ "hidden_act": "quick_gelu",
11
+ "hidden_size": 768,
12
+ "initializer_factor": 1.0,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 3072,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 77,
17
+ "model_type": "clip_text_model",
18
+ "num_attention_heads": 12,
19
+ "num_hidden_layers": 12,
20
+ "pad_token_id": 1,
21
+ "projection_dim": 768,
22
+ "torch_dtype": "bfloat16",
23
+ "transformers_version": "4.43.3",
24
+ "vocab_size": 49408
25
+ }
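
This is the standard `openai/clip-vit-large-patch14` text encoder; the pipeline only uses its 768-dimensional pooled output (see `_get_clip_prompt_embeds` above). A minimal sketch of loading it on its own, with the repo id as a placeholder:

    import torch
    from transformers import CLIPTextModel, CLIPTokenizer

    repo = "<repo-id>"  # placeholder
    tokenizer = CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(repo, subfolder="text_encoder", torch_dtype=torch.bfloat16)

    tokens = tokenizer("a photo of a cat", padding="max_length", max_length=77,
                       truncation=True, return_tensors="pt")
    pooled = text_encoder(tokens.input_ids).pooler_output  # shape (1, 768)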
text_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:893d67a23f4693ed42cdab4cbad7fe3e727cf59609c40da28a46b5470f9ed082
3
+ size 246144352
text_encoder_2/config.json ADDED
@@ -0,0 +1,32 @@
1
+ {
2
+ "_name_or_path": "google/t5-v1_1-xxl",
3
+ "architectures": [
4
+ "T5EncoderModel"
5
+ ],
6
+ "classifier_dropout": 0.0,
7
+ "d_ff": 10240,
8
+ "d_kv": 64,
9
+ "d_model": 4096,
10
+ "decoder_start_token_id": 0,
11
+ "dense_act_fn": "gelu_new",
12
+ "dropout_rate": 0.1,
13
+ "eos_token_id": 1,
14
+ "feed_forward_proj": "gated-gelu",
15
+ "initializer_factor": 1.0,
16
+ "is_encoder_decoder": true,
17
+ "is_gated_act": true,
18
+ "layer_norm_epsilon": 1e-06,
19
+ "model_type": "t5",
20
+ "num_decoder_layers": 24,
21
+ "num_heads": 64,
22
+ "num_layers": 24,
23
+ "output_past": true,
24
+ "pad_token_id": 0,
25
+ "relative_attention_max_distance": 128,
26
+ "relative_attention_num_buckets": 32,
27
+ "tie_word_embeddings": false,
28
+ "torch_dtype": "bfloat16",
29
+ "transformers_version": "4.43.3",
30
+ "use_cache": true,
31
+ "vocab_size": 32128
32
+ }
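
The secondary encoder is a T5-v1.1-XXL encoder whose 4096-dimensional hidden states provide the main text conditioning, padded or truncated to `max_sequence_length` (512 by default). A minimal loading sketch; the repo id is a placeholder and the padding settings mirror what `_get_t5_prompt_embeds` is assumed to use:

    import torch
    from transformers import T5EncoderModel, T5TokenizerFast

    repo = "<repo-id>"  # placeholder
    tokenizer_2 = T5TokenizerFast.from_pretrained(repo, subfolder="tokenizer_2")
    text_encoder_2 = T5EncoderModel.from_pretrained(repo, subfolder="text_encoder_2", torch_dtype=torch.bfloat16)

    tokens = tokenizer_2("a photo of a cat", padding="max_length", max_length=512,
                         truncation=True, return_tensors="pt")
    embeds = text_encoder_2(tokens.input_ids).last_hidden_state  # shape (1, 512, 4096)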
text_encoder_2/model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec87bffd1923e8b2774a6d240c922a41f6143081d52cf83b8fe39e9d838c893e
3
+ size 4994582224
text_encoder_2/model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a5640855b301fcdbceddfa90ae8066cd9414aff020552a201a255ecf2059da00
3
+ size 4530066360
text_encoder_2/model.safetensors.index.json ADDED
@@ -0,0 +1,226 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 9524621312
4
+ },
5
+ "weight_map": {
6
+ "encoder.block.0.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
7
+ "encoder.block.0.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
8
+ "encoder.block.0.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
9
+ "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight": "model-00001-of-00002.safetensors",
10
+ "encoder.block.0.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
11
+ "encoder.block.0.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
12
+ "encoder.block.0.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
13
+ "encoder.block.0.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
14
+ "encoder.block.0.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
15
+ "encoder.block.0.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
16
+ "encoder.block.1.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
17
+ "encoder.block.1.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
18
+ "encoder.block.1.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
19
+ "encoder.block.1.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
20
+ "encoder.block.1.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
21
+ "encoder.block.1.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
22
+ "encoder.block.1.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
23
+ "encoder.block.1.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
24
+ "encoder.block.1.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
25
+ "encoder.block.10.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
26
+ "encoder.block.10.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
27
+ "encoder.block.10.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
28
+ "encoder.block.10.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
29
+ "encoder.block.10.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
30
+ "encoder.block.10.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
31
+ "encoder.block.10.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
32
+ "encoder.block.10.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
33
+ "encoder.block.10.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
34
+ "encoder.block.11.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
35
+ "encoder.block.11.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
36
+ "encoder.block.11.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
37
+ "encoder.block.11.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
38
+ "encoder.block.11.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
39
+ "encoder.block.11.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
40
+ "encoder.block.11.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
41
+ "encoder.block.11.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
42
+ "encoder.block.11.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
43
+ "encoder.block.12.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
44
+ "encoder.block.12.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
45
+ "encoder.block.12.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
46
+ "encoder.block.12.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
47
+ "encoder.block.12.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
48
+ "encoder.block.12.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
49
+ "encoder.block.12.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
50
+ "encoder.block.12.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
51
+ "encoder.block.12.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
52
+ "encoder.block.13.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
53
+ "encoder.block.13.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
54
+ "encoder.block.13.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
55
+ "encoder.block.13.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
56
+ "encoder.block.13.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
57
+ "encoder.block.13.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
58
+ "encoder.block.13.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
59
+ "encoder.block.13.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
60
+ "encoder.block.13.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
61
+ "encoder.block.14.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
62
+ "encoder.block.14.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
63
+ "encoder.block.14.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
64
+ "encoder.block.14.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
65
+ "encoder.block.14.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
66
+ "encoder.block.14.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
67
+ "encoder.block.14.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
68
+ "encoder.block.14.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
69
+ "encoder.block.14.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
70
+ "encoder.block.15.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
71
+ "encoder.block.15.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
72
+ "encoder.block.15.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
73
+ "encoder.block.15.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
74
+ "encoder.block.15.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
75
+ "encoder.block.15.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
76
+ "encoder.block.15.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
77
+ "encoder.block.15.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
78
+ "encoder.block.15.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
79
+ "encoder.block.16.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
80
+ "encoder.block.16.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
81
+ "encoder.block.16.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
82
+ "encoder.block.16.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
83
+ "encoder.block.16.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
84
+ "encoder.block.16.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
85
+ "encoder.block.16.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
86
+ "encoder.block.16.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
87
+ "encoder.block.16.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
88
+ "encoder.block.17.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
89
+ "encoder.block.17.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
90
+ "encoder.block.17.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
91
+ "encoder.block.17.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
92
+ "encoder.block.17.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
93
+ "encoder.block.17.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
94
+ "encoder.block.17.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
95
+ "encoder.block.17.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
96
+ "encoder.block.17.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
97
+ "encoder.block.18.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
98
+ "encoder.block.18.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
99
+ "encoder.block.18.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
100
+ "encoder.block.18.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
101
+ "encoder.block.18.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
102
+ "encoder.block.18.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
103
+ "encoder.block.18.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
104
+ "encoder.block.18.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
105
+ "encoder.block.18.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
106
+ "encoder.block.19.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
107
+ "encoder.block.19.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
108
+ "encoder.block.19.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
109
+ "encoder.block.19.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
110
+ "encoder.block.19.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
111
+ "encoder.block.19.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
112
+ "encoder.block.19.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
113
+ "encoder.block.19.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
114
+ "encoder.block.19.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
115
+ "encoder.block.2.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
116
+ "encoder.block.2.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
117
+ "encoder.block.2.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
118
+ "encoder.block.2.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
119
+ "encoder.block.2.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
120
+ "encoder.block.2.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
121
+ "encoder.block.2.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
122
+ "encoder.block.2.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
123
+ "encoder.block.2.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
124
+ "encoder.block.20.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
125
+ "encoder.block.20.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
126
+ "encoder.block.20.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
127
+ "encoder.block.20.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
128
+ "encoder.block.20.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
129
+ "encoder.block.20.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
130
+ "encoder.block.20.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
131
+ "encoder.block.20.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
132
+ "encoder.block.20.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
133
+ "encoder.block.21.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
134
+ "encoder.block.21.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
135
+ "encoder.block.21.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
136
+ "encoder.block.21.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
137
+ "encoder.block.21.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
138
+ "encoder.block.21.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
139
+ "encoder.block.21.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
140
+ "encoder.block.21.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
141
+ "encoder.block.21.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
142
+ "encoder.block.22.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
143
+ "encoder.block.22.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
144
+ "encoder.block.22.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
145
+ "encoder.block.22.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
146
+ "encoder.block.22.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
147
+ "encoder.block.22.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
148
+ "encoder.block.22.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
149
+ "encoder.block.22.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
150
+ "encoder.block.22.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
151
+ "encoder.block.23.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
152
+ "encoder.block.23.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
153
+ "encoder.block.23.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
154
+ "encoder.block.23.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
155
+ "encoder.block.23.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
156
+ "encoder.block.23.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
157
+ "encoder.block.23.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
158
+ "encoder.block.23.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
159
+ "encoder.block.23.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
160
+ "encoder.block.3.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
161
+ "encoder.block.3.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
162
+ "encoder.block.3.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
163
+ "encoder.block.3.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
164
+ "encoder.block.3.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
165
+ "encoder.block.3.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
166
+ "encoder.block.3.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
167
+ "encoder.block.3.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
168
+ "encoder.block.3.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
169
+ "encoder.block.4.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
170
+ "encoder.block.4.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
171
+ "encoder.block.4.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
172
+ "encoder.block.4.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
173
+ "encoder.block.4.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
174
+ "encoder.block.4.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
175
+ "encoder.block.4.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
176
+ "encoder.block.4.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
177
+ "encoder.block.4.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
178
+ "encoder.block.5.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
179
+ "encoder.block.5.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
180
+ "encoder.block.5.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
181
+ "encoder.block.5.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
182
+ "encoder.block.5.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
183
+ "encoder.block.5.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
184
+ "encoder.block.5.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
185
+ "encoder.block.5.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
186
+ "encoder.block.5.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
187
+ "encoder.block.6.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
188
+ "encoder.block.6.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
189
+ "encoder.block.6.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
190
+ "encoder.block.6.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
191
+ "encoder.block.6.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
192
+ "encoder.block.6.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
193
+ "encoder.block.6.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
194
+ "encoder.block.6.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
195
+ "encoder.block.6.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
196
+ "encoder.block.7.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
197
+ "encoder.block.7.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
198
+ "encoder.block.7.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
199
+ "encoder.block.7.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
200
+ "encoder.block.7.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
201
+ "encoder.block.7.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
202
+ "encoder.block.7.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
203
+ "encoder.block.7.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
204
+ "encoder.block.7.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
205
+ "encoder.block.8.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
206
+ "encoder.block.8.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
207
+ "encoder.block.8.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
208
+ "encoder.block.8.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
209
+ "encoder.block.8.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
210
+ "encoder.block.8.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
211
+ "encoder.block.8.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
212
+ "encoder.block.8.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
213
+ "encoder.block.8.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
214
+ "encoder.block.9.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
215
+ "encoder.block.9.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
216
+ "encoder.block.9.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
217
+ "encoder.block.9.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
218
+ "encoder.block.9.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
219
+ "encoder.block.9.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
220
+ "encoder.block.9.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
221
+ "encoder.block.9.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
222
+ "encoder.block.9.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
223
+ "encoder.final_layer_norm.weight": "model-00002-of-00002.safetensors",
224
+ "shared.weight": "model-00001-of-00002.safetensors"
225
+ }
226
+ }
tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<|startoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<|endoftext|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<|endoftext|>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "49406": {
5
+ "content": "<|startoftext|>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "49407": {
13
+ "content": "<|endoftext|>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ }
20
+ },
21
+ "bos_token": "<|startoftext|>",
22
+ "clean_up_tokenization_spaces": true,
23
+ "do_lower_case": true,
24
+ "eos_token": "<|endoftext|>",
25
+ "errors": "replace",
26
+ "model_max_length": 77,
27
+ "pad_token": "<|endoftext|>",
28
+ "tokenizer_class": "CLIPTokenizer",
29
+ "unk_token": "<|endoftext|>"
30
+ }
tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_2/special_tokens_map.json ADDED
@@ -0,0 +1,125 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<extra_id_0>",
4
+ "<extra_id_1>",
5
+ "<extra_id_2>",
6
+ "<extra_id_3>",
7
+ "<extra_id_4>",
8
+ "<extra_id_5>",
9
+ "<extra_id_6>",
10
+ "<extra_id_7>",
11
+ "<extra_id_8>",
12
+ "<extra_id_9>",
13
+ "<extra_id_10>",
14
+ "<extra_id_11>",
15
+ "<extra_id_12>",
16
+ "<extra_id_13>",
17
+ "<extra_id_14>",
18
+ "<extra_id_15>",
19
+ "<extra_id_16>",
20
+ "<extra_id_17>",
21
+ "<extra_id_18>",
22
+ "<extra_id_19>",
23
+ "<extra_id_20>",
24
+ "<extra_id_21>",
25
+ "<extra_id_22>",
26
+ "<extra_id_23>",
27
+ "<extra_id_24>",
28
+ "<extra_id_25>",
29
+ "<extra_id_26>",
30
+ "<extra_id_27>",
31
+ "<extra_id_28>",
32
+ "<extra_id_29>",
33
+ "<extra_id_30>",
34
+ "<extra_id_31>",
35
+ "<extra_id_32>",
36
+ "<extra_id_33>",
37
+ "<extra_id_34>",
38
+ "<extra_id_35>",
39
+ "<extra_id_36>",
40
+ "<extra_id_37>",
41
+ "<extra_id_38>",
42
+ "<extra_id_39>",
43
+ "<extra_id_40>",
44
+ "<extra_id_41>",
45
+ "<extra_id_42>",
46
+ "<extra_id_43>",
47
+ "<extra_id_44>",
48
+ "<extra_id_45>",
49
+ "<extra_id_46>",
50
+ "<extra_id_47>",
51
+ "<extra_id_48>",
52
+ "<extra_id_49>",
53
+ "<extra_id_50>",
54
+ "<extra_id_51>",
55
+ "<extra_id_52>",
56
+ "<extra_id_53>",
57
+ "<extra_id_54>",
58
+ "<extra_id_55>",
59
+ "<extra_id_56>",
60
+ "<extra_id_57>",
61
+ "<extra_id_58>",
62
+ "<extra_id_59>",
63
+ "<extra_id_60>",
64
+ "<extra_id_61>",
65
+ "<extra_id_62>",
66
+ "<extra_id_63>",
67
+ "<extra_id_64>",
68
+ "<extra_id_65>",
69
+ "<extra_id_66>",
70
+ "<extra_id_67>",
71
+ "<extra_id_68>",
72
+ "<extra_id_69>",
73
+ "<extra_id_70>",
74
+ "<extra_id_71>",
75
+ "<extra_id_72>",
76
+ "<extra_id_73>",
77
+ "<extra_id_74>",
78
+ "<extra_id_75>",
79
+ "<extra_id_76>",
80
+ "<extra_id_77>",
81
+ "<extra_id_78>",
82
+ "<extra_id_79>",
83
+ "<extra_id_80>",
84
+ "<extra_id_81>",
85
+ "<extra_id_82>",
86
+ "<extra_id_83>",
87
+ "<extra_id_84>",
88
+ "<extra_id_85>",
89
+ "<extra_id_86>",
90
+ "<extra_id_87>",
91
+ "<extra_id_88>",
92
+ "<extra_id_89>",
93
+ "<extra_id_90>",
94
+ "<extra_id_91>",
95
+ "<extra_id_92>",
96
+ "<extra_id_93>",
97
+ "<extra_id_94>",
98
+ "<extra_id_95>",
99
+ "<extra_id_96>",
100
+ "<extra_id_97>",
101
+ "<extra_id_98>",
102
+ "<extra_id_99>"
103
+ ],
104
+ "eos_token": {
105
+ "content": "</s>",
106
+ "lstrip": false,
107
+ "normalized": false,
108
+ "rstrip": false,
109
+ "single_word": false
110
+ },
111
+ "pad_token": {
112
+ "content": "<pad>",
113
+ "lstrip": false,
114
+ "normalized": false,
115
+ "rstrip": false,
116
+ "single_word": false
117
+ },
118
+ "unk_token": {
119
+ "content": "<unk>",
120
+ "lstrip": false,
121
+ "normalized": false,
122
+ "rstrip": false,
123
+ "single_word": false
124
+ }
125
+ }
tokenizer_2/spiece.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86
3
+ size 791656
tokenizer_2/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_2/tokenizer_config.json ADDED
@@ -0,0 +1,940 @@
1
+ {
2
+ "add_prefix_space": true,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "<pad>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "1": {
13
+ "content": "</s>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "2": {
21
+ "content": "<unk>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "32000": {
29
+ "content": "<extra_id_99>",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "32001": {
37
+ "content": "<extra_id_98>",
38
+ "lstrip": false,
39
+ "normalized": false,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": true
43
+ },
44
+ "32002": {
45
+ "content": "<extra_id_97>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false,
50
+ "special": true
51
+ },
52
+ "32003": {
53
+ "content": "<extra_id_96>",
54
+ "lstrip": false,
55
+ "normalized": false,
56
+ "rstrip": false,
57
+ "single_word": false,
58
+ "special": true
59
+ },
60
+ "32004": {
61
+ "content": "<extra_id_95>",
62
+ "lstrip": false,
63
+ "normalized": false,
64
+ "rstrip": false,
65
+ "single_word": false,
66
+ "special": true
67
+ },
68
+ "32005": {
69
+ "content": "<extra_id_94>",
70
+ "lstrip": false,
71
+ "normalized": false,
72
+ "rstrip": false,
73
+ "single_word": false,
74
+ "special": true
75
+ },
76
+ "32006": {
77
+ "content": "<extra_id_93>",
78
+ "lstrip": false,
79
+ "normalized": false,
80
+ "rstrip": false,
81
+ "single_word": false,
82
+ "special": true
83
+ },
84
+ "32007": {
85
+ "content": "<extra_id_92>",
86
+ "lstrip": false,
87
+ "normalized": false,
88
+ "rstrip": false,
89
+ "single_word": false,
90
+ "special": true
91
+ },
92
+ "32008": {
93
+ "content": "<extra_id_91>",
94
+ "lstrip": false,
95
+ "normalized": false,
96
+ "rstrip": false,
97
+ "single_word": false,
98
+ "special": true
99
+ },
100
+ "32009": {
101
+ "content": "<extra_id_90>",
102
+ "lstrip": false,
103
+ "normalized": false,
104
+ "rstrip": false,
105
+ "single_word": false,
106
+ "special": true
107
+ },
108
+ "32010": {
109
+ "content": "<extra_id_89>",
110
+ "lstrip": false,
111
+ "normalized": false,
112
+ "rstrip": false,
113
+ "single_word": false,
114
+ "special": true
115
+ },
116
+ "32011": {
117
+ "content": "<extra_id_88>",
118
+ "lstrip": false,
119
+ "normalized": false,
120
+ "rstrip": false,
121
+ "single_word": false,
122
+ "special": true
123
+ },
124
+ "32012": {
125
+ "content": "<extra_id_87>",
126
+ "lstrip": false,
127
+ "normalized": false,
128
+ "rstrip": false,
129
+ "single_word": false,
130
+ "special": true
131
+ },
132
+ "32013": {
133
+ "content": "<extra_id_86>",
134
+ "lstrip": false,
135
+ "normalized": false,
136
+ "rstrip": false,
137
+ "single_word": false,
138
+ "special": true
139
+ },
140
+ "32014": {
141
+ "content": "<extra_id_85>",
142
+ "lstrip": false,
143
+ "normalized": false,
144
+ "rstrip": false,
145
+ "single_word": false,
146
+ "special": true
147
+ },
148
+ "32015": {
149
+ "content": "<extra_id_84>",
150
+ "lstrip": false,
151
+ "normalized": false,
152
+ "rstrip": false,
153
+ "single_word": false,
154
+ "special": true
155
+ },
156
+ "32016": {
157
+ "content": "<extra_id_83>",
158
+ "lstrip": false,
159
+ "normalized": false,
160
+ "rstrip": false,
161
+ "single_word": false,
162
+ "special": true
163
+ },
164
+ "32017": {
165
+ "content": "<extra_id_82>",
166
+ "lstrip": false,
167
+ "normalized": false,
168
+ "rstrip": false,
169
+ "single_word": false,
170
+ "special": true
171
+ },
172
+ "32018": {
173
+ "content": "<extra_id_81>",
174
+ "lstrip": false,
175
+ "normalized": false,
176
+ "rstrip": false,
177
+ "single_word": false,
178
+ "special": true
179
+ },
180
+ "32019": {
181
+ "content": "<extra_id_80>",
182
+ "lstrip": false,
183
+ "normalized": false,
184
+ "rstrip": false,
185
+ "single_word": false,
186
+ "special": true
187
+ },
188
+ "32020": {
189
+ "content": "<extra_id_79>",
190
+ "lstrip": false,
191
+ "normalized": false,
192
+ "rstrip": false,
193
+ "single_word": false,
194
+ "special": true
195
+ },
196
+ "32021": {
197
+ "content": "<extra_id_78>",
198
+ "lstrip": false,
199
+ "normalized": false,
200
+ "rstrip": false,
201
+ "single_word": false,
202
+ "special": true
203
+ },
204
+ "32022": {
205
+ "content": "<extra_id_77>",
206
+ "lstrip": false,
207
+ "normalized": false,
208
+ "rstrip": false,
209
+ "single_word": false,
210
+ "special": true
211
+ },
212
+ "32023": {
213
+ "content": "<extra_id_76>",
214
+ "lstrip": false,
215
+ "normalized": false,
216
+ "rstrip": false,
217
+ "single_word": false,
218
+ "special": true
219
+ },
220
+ "32024": {
221
+ "content": "<extra_id_75>",
222
+ "lstrip": false,
223
+ "normalized": false,
224
+ "rstrip": false,
225
+ "single_word": false,
226
+ "special": true
227
+ },
228
+ "32025": {
229
+ "content": "<extra_id_74>",
230
+ "lstrip": false,
231
+ "normalized": false,
232
+ "rstrip": false,
233
+ "single_word": false,
234
+ "special": true
235
+ },
236
+ "32026": {
237
+ "content": "<extra_id_73>",
238
+ "lstrip": false,
239
+ "normalized": false,
240
+ "rstrip": false,
241
+ "single_word": false,
242
+ "special": true
243
+ },
244
+ "32027": {
245
+ "content": "<extra_id_72>",
246
+ "lstrip": false,
247
+ "normalized": false,
248
+ "rstrip": false,
249
+ "single_word": false,
250
+ "special": true
251
+ },
252
+ "32028": {
253
+ "content": "<extra_id_71>",
254
+ "lstrip": false,
255
+ "normalized": false,
256
+ "rstrip": false,
257
+ "single_word": false,
258
+ "special": true
259
+ },
260
+ "32029": {
261
+ "content": "<extra_id_70>",
262
+ "lstrip": false,
263
+ "normalized": false,
264
+ "rstrip": false,
265
+ "single_word": false,
266
+ "special": true
267
+ },
268
+ "32030": {
269
+ "content": "<extra_id_69>",
270
+ "lstrip": false,
271
+ "normalized": false,
272
+ "rstrip": false,
273
+ "single_word": false,
274
+ "special": true
275
+ },
276
+ "32031": {
277
+ "content": "<extra_id_68>",
278
+ "lstrip": false,
279
+ "normalized": false,
280
+ "rstrip": false,
281
+ "single_word": false,
282
+ "special": true
283
+ },
284
+ "32032": {
285
+ "content": "<extra_id_67>",
286
+ "lstrip": false,
287
+ "normalized": false,
288
+ "rstrip": false,
289
+ "single_word": false,
290
+ "special": true
291
+ },
292
+ "32033": {
293
+ "content": "<extra_id_66>",
294
+ "lstrip": false,
295
+ "normalized": false,
296
+ "rstrip": false,
297
+ "single_word": false,
298
+ "special": true
299
+ },
300
+ "32034": {
301
+ "content": "<extra_id_65>",
302
+ "lstrip": false,
303
+ "normalized": false,
304
+ "rstrip": false,
305
+ "single_word": false,
306
+ "special": true
307
+ },
308
+ "32035": {
309
+ "content": "<extra_id_64>",
310
+ "lstrip": false,
311
+ "normalized": false,
312
+ "rstrip": false,
313
+ "single_word": false,
314
+ "special": true
315
+ },
316
+ "32036": {
317
+ "content": "<extra_id_63>",
318
+ "lstrip": false,
319
+ "normalized": false,
320
+ "rstrip": false,
321
+ "single_word": false,
322
+ "special": true
323
+ },
324
+ "32037": {
325
+ "content": "<extra_id_62>",
326
+ "lstrip": false,
327
+ "normalized": false,
328
+ "rstrip": false,
329
+ "single_word": false,
330
+ "special": true
331
+ },
332
+ "32038": {
333
+ "content": "<extra_id_61>",
334
+ "lstrip": false,
335
+ "normalized": false,
336
+ "rstrip": false,
337
+ "single_word": false,
338
+ "special": true
339
+ },
340
+ "32039": {
341
+ "content": "<extra_id_60>",
342
+ "lstrip": false,
343
+ "normalized": false,
344
+ "rstrip": false,
345
+ "single_word": false,
346
+ "special": true
347
+ },
348
+ "32040": {
349
+ "content": "<extra_id_59>",
350
+ "lstrip": false,
351
+ "normalized": false,
352
+ "rstrip": false,
353
+ "single_word": false,
354
+ "special": true
355
+ },
356
+ "32041": {
357
+ "content": "<extra_id_58>",
358
+ "lstrip": false,
359
+ "normalized": false,
360
+ "rstrip": false,
361
+ "single_word": false,
362
+ "special": true
363
+ },
364
+ "32042": {
365
+ "content": "<extra_id_57>",
366
+ "lstrip": false,
367
+ "normalized": false,
368
+ "rstrip": false,
369
+ "single_word": false,
370
+ "special": true
371
+ },
372
+ "32043": {
373
+ "content": "<extra_id_56>",
374
+ "lstrip": false,
375
+ "normalized": false,
376
+ "rstrip": false,
377
+ "single_word": false,
378
+ "special": true
379
+ },
380
+ "32044": {
381
+ "content": "<extra_id_55>",
382
+ "lstrip": false,
383
+ "normalized": false,
384
+ "rstrip": false,
385
+ "single_word": false,
386
+ "special": true
387
+ },
388
+ "32045": {
389
+ "content": "<extra_id_54>",
390
+ "lstrip": false,
391
+ "normalized": false,
392
+ "rstrip": false,
393
+ "single_word": false,
394
+ "special": true
395
+ },
396
+ "32046": {
397
+ "content": "<extra_id_53>",
398
+ "lstrip": false,
399
+ "normalized": false,
400
+ "rstrip": false,
401
+ "single_word": false,
402
+ "special": true
403
+ },
404
+ "32047": {
405
+ "content": "<extra_id_52>",
406
+ "lstrip": false,
407
+ "normalized": false,
408
+ "rstrip": false,
409
+ "single_word": false,
410
+ "special": true
411
+ },
412
+ "32048": {
413
+ "content": "<extra_id_51>",
414
+ "lstrip": false,
415
+ "normalized": false,
416
+ "rstrip": false,
417
+ "single_word": false,
418
+ "special": true
419
+ },
420
+ "32049": {
421
+ "content": "<extra_id_50>",
422
+ "lstrip": false,
423
+ "normalized": false,
424
+ "rstrip": false,
425
+ "single_word": false,
426
+ "special": true
427
+ },
428
+ "32050": {
429
+ "content": "<extra_id_49>",
430
+ "lstrip": false,
431
+ "normalized": false,
432
+ "rstrip": false,
433
+ "single_word": false,
434
+ "special": true
435
+ },
436
+ "32051": {
437
+ "content": "<extra_id_48>",
438
+ "lstrip": false,
439
+ "normalized": false,
440
+ "rstrip": false,
441
+ "single_word": false,
442
+ "special": true
443
+ },
444
+ "32052": {
445
+ "content": "<extra_id_47>",
446
+ "lstrip": false,
447
+ "normalized": false,
448
+ "rstrip": false,
449
+ "single_word": false,
450
+ "special": true
451
+ },
452
+ "32053": {
453
+ "content": "<extra_id_46>",
454
+ "lstrip": false,
455
+ "normalized": false,
456
+ "rstrip": false,
457
+ "single_word": false,
458
+ "special": true
459
+ },
460
+ "32054": {
461
+ "content": "<extra_id_45>",
462
+ "lstrip": false,
463
+ "normalized": false,
464
+ "rstrip": false,
465
+ "single_word": false,
466
+ "special": true
467
+ },
468
+ "32055": {
469
+ "content": "<extra_id_44>",
470
+ "lstrip": false,
471
+ "normalized": false,
472
+ "rstrip": false,
473
+ "single_word": false,
474
+ "special": true
475
+ },
476
+ "32056": {
477
+ "content": "<extra_id_43>",
478
+ "lstrip": false,
479
+ "normalized": false,
480
+ "rstrip": false,
481
+ "single_word": false,
482
+ "special": true
483
+ },
484
+ "32057": {
485
+ "content": "<extra_id_42>",
486
+ "lstrip": false,
487
+ "normalized": false,
488
+ "rstrip": false,
489
+ "single_word": false,
490
+ "special": true
491
+ },
492
+ "32058": {
493
+ "content": "<extra_id_41>",
494
+ "lstrip": false,
495
+ "normalized": false,
496
+ "rstrip": false,
497
+ "single_word": false,
498
+ "special": true
499
+ },
500
+ "32059": {
501
+ "content": "<extra_id_40>",
502
+ "lstrip": false,
503
+ "normalized": false,
504
+ "rstrip": false,
505
+ "single_word": false,
506
+ "special": true
507
+ },
508
+ "32060": {
509
+ "content": "<extra_id_39>",
510
+ "lstrip": false,
511
+ "normalized": false,
512
+ "rstrip": false,
513
+ "single_word": false,
514
+ "special": true
515
+ },
516
+ "32061": {
517
+ "content": "<extra_id_38>",
518
+ "lstrip": false,
519
+ "normalized": false,
520
+ "rstrip": false,
521
+ "single_word": false,
522
+ "special": true
523
+ },
524
+ "32062": {
525
+ "content": "<extra_id_37>",
526
+ "lstrip": false,
527
+ "normalized": false,
528
+ "rstrip": false,
529
+ "single_word": false,
530
+ "special": true
531
+ },
532
+ "32063": {
533
+ "content": "<extra_id_36>",
534
+ "lstrip": false,
535
+ "normalized": false,
536
+ "rstrip": false,
537
+ "single_word": false,
538
+ "special": true
539
+ },
540
+ "32064": {
541
+ "content": "<extra_id_35>",
542
+ "lstrip": false,
543
+ "normalized": false,
544
+ "rstrip": false,
545
+ "single_word": false,
546
+ "special": true
547
+ },
548
+ "32065": {
549
+ "content": "<extra_id_34>",
550
+ "lstrip": false,
551
+ "normalized": false,
552
+ "rstrip": false,
553
+ "single_word": false,
554
+ "special": true
555
+ },
556
+ "32066": {
557
+ "content": "<extra_id_33>",
558
+ "lstrip": false,
559
+ "normalized": false,
560
+ "rstrip": false,
561
+ "single_word": false,
562
+ "special": true
563
+ },
564
+ "32067": {
565
+ "content": "<extra_id_32>",
566
+ "lstrip": false,
567
+ "normalized": false,
568
+ "rstrip": false,
569
+ "single_word": false,
570
+ "special": true
571
+ },
572
+ "32068": {
573
+ "content": "<extra_id_31>",
574
+ "lstrip": false,
575
+ "normalized": false,
576
+ "rstrip": false,
577
+ "single_word": false,
578
+ "special": true
579
+ },
580
+ "32069": {
581
+ "content": "<extra_id_30>",
582
+ "lstrip": false,
583
+ "normalized": false,
584
+ "rstrip": false,
585
+ "single_word": false,
586
+ "special": true
587
+ },
588
+ "32070": {
589
+ "content": "<extra_id_29>",
590
+ "lstrip": false,
591
+ "normalized": false,
592
+ "rstrip": false,
593
+ "single_word": false,
594
+ "special": true
595
+ },
596
+ "32071": {
597
+ "content": "<extra_id_28>",
598
+ "lstrip": false,
599
+ "normalized": false,
600
+ "rstrip": false,
601
+ "single_word": false,
602
+ "special": true
603
+ },
604
+ "32072": {
605
+ "content": "<extra_id_27>",
606
+ "lstrip": false,
607
+ "normalized": false,
608
+ "rstrip": false,
609
+ "single_word": false,
610
+ "special": true
611
+ },
612
+ "32073": {
613
+ "content": "<extra_id_26>",
614
+ "lstrip": false,
615
+ "normalized": false,
616
+ "rstrip": false,
617
+ "single_word": false,
618
+ "special": true
619
+ },
620
+ "32074": {
621
+ "content": "<extra_id_25>",
622
+ "lstrip": false,
623
+ "normalized": false,
624
+ "rstrip": false,
625
+ "single_word": false,
626
+ "special": true
627
+ },
628
+ "32075": {
629
+ "content": "<extra_id_24>",
630
+ "lstrip": false,
631
+ "normalized": false,
632
+ "rstrip": false,
633
+ "single_word": false,
634
+ "special": true
635
+ },
636
+ "32076": {
637
+ "content": "<extra_id_23>",
638
+ "lstrip": false,
639
+ "normalized": false,
640
+ "rstrip": false,
641
+ "single_word": false,
642
+ "special": true
643
+ },
644
+ "32077": {
645
+ "content": "<extra_id_22>",
646
+ "lstrip": false,
647
+ "normalized": false,
648
+ "rstrip": false,
649
+ "single_word": false,
650
+ "special": true
651
+ },
652
+ "32078": {
653
+ "content": "<extra_id_21>",
654
+ "lstrip": false,
655
+ "normalized": false,
656
+ "rstrip": false,
657
+ "single_word": false,
658
+ "special": true
659
+ },
660
+ "32079": {
661
+ "content": "<extra_id_20>",
662
+ "lstrip": false,
663
+ "normalized": false,
664
+ "rstrip": false,
665
+ "single_word": false,
666
+ "special": true
667
+ },
668
+ "32080": {
669
+ "content": "<extra_id_19>",
670
+ "lstrip": false,
671
+ "normalized": false,
672
+ "rstrip": false,
673
+ "single_word": false,
674
+ "special": true
675
+ },
676
+ "32081": {
677
+ "content": "<extra_id_18>",
678
+ "lstrip": false,
679
+ "normalized": false,
680
+ "rstrip": false,
681
+ "single_word": false,
682
+ "special": true
683
+ },
684
+ "32082": {
685
+ "content": "<extra_id_17>",
686
+ "lstrip": false,
687
+ "normalized": false,
688
+ "rstrip": false,
689
+ "single_word": false,
690
+ "special": true
691
+ },
692
+ "32083": {
693
+ "content": "<extra_id_16>",
694
+ "lstrip": false,
695
+ "normalized": false,
696
+ "rstrip": false,
697
+ "single_word": false,
698
+ "special": true
699
+ },
700
+ "32084": {
701
+ "content": "<extra_id_15>",
702
+ "lstrip": false,
703
+ "normalized": false,
704
+ "rstrip": false,
705
+ "single_word": false,
706
+ "special": true
707
+ },
708
+ "32085": {
709
+ "content": "<extra_id_14>",
710
+ "lstrip": false,
711
+ "normalized": false,
712
+ "rstrip": false,
713
+ "single_word": false,
714
+ "special": true
715
+ },
716
+ "32086": {
717
+ "content": "<extra_id_13>",
718
+ "lstrip": false,
719
+ "normalized": false,
720
+ "rstrip": false,
721
+ "single_word": false,
722
+ "special": true
723
+ },
724
+ "32087": {
725
+ "content": "<extra_id_12>",
726
+ "lstrip": false,
727
+ "normalized": false,
728
+ "rstrip": false,
729
+ "single_word": false,
730
+ "special": true
731
+ },
732
+ "32088": {
733
+ "content": "<extra_id_11>",
734
+ "lstrip": false,
735
+ "normalized": false,
736
+ "rstrip": false,
737
+ "single_word": false,
738
+ "special": true
739
+ },
740
+ "32089": {
741
+ "content": "<extra_id_10>",
742
+ "lstrip": false,
743
+ "normalized": false,
744
+ "rstrip": false,
745
+ "single_word": false,
746
+ "special": true
747
+ },
748
+ "32090": {
749
+ "content": "<extra_id_9>",
750
+ "lstrip": false,
751
+ "normalized": false,
752
+ "rstrip": false,
753
+ "single_word": false,
754
+ "special": true
755
+ },
756
+ "32091": {
757
+ "content": "<extra_id_8>",
758
+ "lstrip": false,
759
+ "normalized": false,
760
+ "rstrip": false,
761
+ "single_word": false,
762
+ "special": true
763
+ },
764
+ "32092": {
765
+ "content": "<extra_id_7>",
766
+ "lstrip": false,
767
+ "normalized": false,
768
+ "rstrip": false,
769
+ "single_word": false,
770
+ "special": true
771
+ },
772
+ "32093": {
773
+ "content": "<extra_id_6>",
774
+ "lstrip": false,
775
+ "normalized": false,
776
+ "rstrip": false,
777
+ "single_word": false,
778
+ "special": true
779
+ },
780
+ "32094": {
781
+ "content": "<extra_id_5>",
782
+ "lstrip": false,
783
+ "normalized": false,
784
+ "rstrip": false,
785
+ "single_word": false,
786
+ "special": true
787
+ },
788
+ "32095": {
789
+ "content": "<extra_id_4>",
790
+ "lstrip": false,
791
+ "normalized": false,
792
+ "rstrip": false,
793
+ "single_word": false,
794
+ "special": true
795
+ },
796
+ "32096": {
797
+ "content": "<extra_id_3>",
798
+ "lstrip": false,
799
+ "normalized": false,
800
+ "rstrip": false,
801
+ "single_word": false,
802
+ "special": true
803
+ },
804
+ "32097": {
805
+ "content": "<extra_id_2>",
806
+ "lstrip": false,
807
+ "normalized": false,
808
+ "rstrip": false,
809
+ "single_word": false,
810
+ "special": true
811
+ },
812
+ "32098": {
813
+ "content": "<extra_id_1>",
814
+ "lstrip": false,
815
+ "normalized": false,
816
+ "rstrip": false,
817
+ "single_word": false,
818
+ "special": true
819
+ },
820
+ "32099": {
821
+ "content": "<extra_id_0>",
822
+ "lstrip": false,
823
+ "normalized": false,
824
+ "rstrip": false,
825
+ "single_word": false,
826
+ "special": true
827
+ }
828
+ },
829
+ "additional_special_tokens": [
830
+ "<extra_id_0>",
831
+ "<extra_id_1>",
832
+ "<extra_id_2>",
833
+ "<extra_id_3>",
834
+ "<extra_id_4>",
835
+ "<extra_id_5>",
836
+ "<extra_id_6>",
837
+ "<extra_id_7>",
838
+ "<extra_id_8>",
839
+ "<extra_id_9>",
840
+ "<extra_id_10>",
841
+ "<extra_id_11>",
842
+ "<extra_id_12>",
843
+ "<extra_id_13>",
844
+ "<extra_id_14>",
845
+ "<extra_id_15>",
846
+ "<extra_id_16>",
847
+ "<extra_id_17>",
848
+ "<extra_id_18>",
849
+ "<extra_id_19>",
850
+ "<extra_id_20>",
851
+ "<extra_id_21>",
852
+ "<extra_id_22>",
853
+ "<extra_id_23>",
854
+ "<extra_id_24>",
855
+ "<extra_id_25>",
856
+ "<extra_id_26>",
857
+ "<extra_id_27>",
858
+ "<extra_id_28>",
859
+ "<extra_id_29>",
860
+ "<extra_id_30>",
861
+ "<extra_id_31>",
862
+ "<extra_id_32>",
863
+ "<extra_id_33>",
864
+ "<extra_id_34>",
865
+ "<extra_id_35>",
866
+ "<extra_id_36>",
867
+ "<extra_id_37>",
868
+ "<extra_id_38>",
869
+ "<extra_id_39>",
870
+ "<extra_id_40>",
871
+ "<extra_id_41>",
872
+ "<extra_id_42>",
873
+ "<extra_id_43>",
874
+ "<extra_id_44>",
875
+ "<extra_id_45>",
876
+ "<extra_id_46>",
877
+ "<extra_id_47>",
878
+ "<extra_id_48>",
879
+ "<extra_id_49>",
880
+ "<extra_id_50>",
881
+ "<extra_id_51>",
882
+ "<extra_id_52>",
883
+ "<extra_id_53>",
884
+ "<extra_id_54>",
885
+ "<extra_id_55>",
886
+ "<extra_id_56>",
887
+ "<extra_id_57>",
888
+ "<extra_id_58>",
889
+ "<extra_id_59>",
890
+ "<extra_id_60>",
891
+ "<extra_id_61>",
892
+ "<extra_id_62>",
893
+ "<extra_id_63>",
894
+ "<extra_id_64>",
895
+ "<extra_id_65>",
896
+ "<extra_id_66>",
897
+ "<extra_id_67>",
898
+ "<extra_id_68>",
899
+ "<extra_id_69>",
900
+ "<extra_id_70>",
901
+ "<extra_id_71>",
902
+ "<extra_id_72>",
903
+ "<extra_id_73>",
904
+ "<extra_id_74>",
905
+ "<extra_id_75>",
906
+ "<extra_id_76>",
907
+ "<extra_id_77>",
908
+ "<extra_id_78>",
909
+ "<extra_id_79>",
910
+ "<extra_id_80>",
911
+ "<extra_id_81>",
912
+ "<extra_id_82>",
913
+ "<extra_id_83>",
914
+ "<extra_id_84>",
915
+ "<extra_id_85>",
916
+ "<extra_id_86>",
917
+ "<extra_id_87>",
918
+ "<extra_id_88>",
919
+ "<extra_id_89>",
920
+ "<extra_id_90>",
921
+ "<extra_id_91>",
922
+ "<extra_id_92>",
923
+ "<extra_id_93>",
924
+ "<extra_id_94>",
925
+ "<extra_id_95>",
926
+ "<extra_id_96>",
927
+ "<extra_id_97>",
928
+ "<extra_id_98>",
929
+ "<extra_id_99>"
930
+ ],
931
+ "clean_up_tokenization_spaces": true,
932
+ "eos_token": "</s>",
933
+ "extra_ids": 100,
934
+ "legacy": true,
935
+ "model_max_length": 512,
936
+ "pad_token": "<pad>",
937
+ "sp_model_kwargs": {},
938
+ "tokenizer_class": "T5Tokenizer",
939
+ "unk_token": "<unk>"
940
+ }
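Note: tokenizer_2 is the SentencePiece T5 tokenizer for the second text encoder; the 100 <extra_id_N> entries above are the standard T5 sentinel tokens. A minimal loading sketch (placeholder repo id, requires the sentencepiece package):

from transformers import T5Tokenizer

# Placeholder repo id; substitute the actual location of this upload.
tokenizer_2 = T5Tokenizer.from_pretrained(
    "your-namespace/LibreFlux-IP-Adapter",
    subfolder="tokenizer_2",
)

# Per the config above: model_max_length 512, </s> as EOS, <pad> for padding.
ids = tokenizer_2("a portrait of a marble statue").input_ids
print(ids[-1])  # 1, the id of </s>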
transformer/config.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "_class_name": "LibreFluxTransformer2DModel",
3
+ "_diffusers_version": "0.35.2",
4
+ "_name_or_path": "/path/to/transformer",
5
+ "attention_head_dim": 128,
6
+ "axes_dims_rope": [
7
+ 16,
8
+ 56,
9
+ 56
10
+ ],
11
+ "guidance_embeds": false,
12
+ "in_channels": 64,
13
+ "joint_attention_dim": 4096,
14
+ "num_attention_heads": 24,
15
+ "num_layers": 19,
16
+ "num_single_layers": 38,
17
+ "patch_size": 1,
18
+ "pooled_projection_dim": 768
19
+ }
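Note: this config instantiates the custom LibreFluxTransformer2DModel defined in transformer/trans.py (added below); guidance_embeds is false, unlike the guidance-distilled Flux checkpoints. A hedged loading sketch; the local path is a placeholder and assumes trans.py is importable:

import torch
from trans import LibreFluxTransformer2DModel  # transformer/trans.py from this commit

# Placeholder path: a local download of this repo's "transformer" folder.
transformer = LibreFluxTransformer2DModel.from_pretrained(
    "/path/to/LibreFlux-IP-Adapter/transformer",
    torch_dtype=torch.bfloat16,
)

# Should match the config above: 19 double-stream blocks, 38 single-stream blocks,
# 24 heads of dim 128, 4096-dim T5 context, 768-dim pooled CLIP projection.
print(transformer.config.num_layers, transformer.config.num_single_layers)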
transformer/diffusion_pytorch_model-00001-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:18c2abe01a326d95bc836cfd5f68167118c0ecb2c8ccbcf5d6de4dbad47ca53c
3
+ size 9962580296
transformer/diffusion_pytorch_model-00002-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:828f131e306b17535c8c1d0a3c4aaa06f2a60a80612500da229829242f3ed422
3
+ size 9949328904
transformer/diffusion_pytorch_model-00003-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a17988ef4255372dd902ff5742e647f8a60dcc83756d740d1fbcf81d13d38162
3
+ size 3870584832
transformer/diffusion_pytorch_model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
transformer/trans.py ADDED
@@ -0,0 +1,794 @@
1
+ # Copyright 2024 Stability AI, The HuggingFace Team, The InstantX Team, and Terminus Research Group. All rights reserved.
2
+ #
3
+ # Originally licensed under the Apache License, Version 2.0 (the "License");
4
+ # Updated to "Affero GENERAL PUBLIC LICENSE Version 3, 19 November 2007" via extensive updates to attn_mask usage.
5
+ # This was taken from SimpleTuner and modified as needed
6
+
7
+ from typing import Any, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import numpy as np
13
+
14
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
15
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
16
+ from diffusers.models.attention import FeedForward
17
+ from diffusers.models.attention_processor import (
18
+ Attention,
19
+ AttentionProcessor,
20
+ )
21
+ from diffusers.models.modeling_utils import ModelMixin
22
+ from diffusers.models.normalization import (
23
+ AdaLayerNormContinuous,
24
+ AdaLayerNormZero,
25
+ AdaLayerNormZeroSingle,
26
+ )
27
+ from diffusers.utils import (
28
+ USE_PEFT_BACKEND,
29
+ is_torch_version,
30
+ logging,
31
+ scale_lora_layers,
32
+ unscale_lora_layers,
33
+ )
34
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
35
+ from diffusers.models.embeddings import (
36
+ CombinedTimestepGuidanceTextProjEmbeddings,
37
+ CombinedTimestepTextProjEmbeddings,
38
+ FluxPosEmbed,
39
+ )
40
+
41
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
42
+ from diffusers import FluxTransformer2DModel as OriginalFluxTransformer2DModel
43
+
44
+
45
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
46
+
47
+
48
+
49
+ class FluxAttnProcessor2_0:
50
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
51
+
52
+ def __init__(self):
53
+ if not hasattr(F, "scaled_dot_product_attention"):
54
+ raise ImportError(
55
+ "FluxAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0 or newer."
56
+ )
57
+
58
+ def __call__(
59
+ self,
60
+ attn: Attention,
61
+ hidden_states: torch.FloatTensor,
62
+ encoder_hidden_states: torch.FloatTensor = None,
63
+ attention_mask: Optional[torch.FloatTensor] = None,
64
+ image_rotary_emb: Optional[torch.Tensor] = None,
65
+ ip_encoder_hidden_states:Optional[torch.Tensor] = None,
66
+ layer_scale:Optional[torch.Tensor] = None,
67
+ ) -> torch.FloatTensor:
68
+ batch_size, _, _ = (
69
+ hidden_states.shape
70
+ if encoder_hidden_states is None
71
+ else encoder_hidden_states.shape
72
+ )
73
+
74
+ # `sample` projections.
75
+ query = attn.to_q(hidden_states)
76
+ key = attn.to_k(hidden_states)
77
+ value = attn.to_v(hidden_states)
78
+
79
+ inner_dim = key.shape[-1]
80
+ head_dim = inner_dim // attn.heads
81
+
82
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
83
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
84
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
85
+
86
+ if attn.norm_q is not None:
87
+ query = attn.norm_q(query)
88
+ if attn.norm_k is not None:
89
+ key = attn.norm_k(key)
90
+
91
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
92
+ if encoder_hidden_states is not None:
93
+ # `context` projections.
94
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
95
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
96
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
97
+
98
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
99
+ batch_size, -1, attn.heads, head_dim
100
+ ).transpose(1, 2)
101
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
102
+ batch_size, -1, attn.heads, head_dim
103
+ ).transpose(1, 2)
104
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
105
+ batch_size, -1, attn.heads, head_dim
106
+ ).transpose(1, 2)
107
+
108
+ if attn.norm_added_q is not None:
109
+ encoder_hidden_states_query_proj = attn.norm_added_q(
110
+ encoder_hidden_states_query_proj
111
+ )
112
+ if attn.norm_added_k is not None:
113
+ encoder_hidden_states_key_proj = attn.norm_added_k(
114
+ encoder_hidden_states_key_proj
115
+ )
116
+
117
+ # attention
118
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
119
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
120
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
121
+
122
+ if image_rotary_emb is not None:
123
+ from diffusers.models.embeddings import apply_rotary_emb
124
+
125
+ query = apply_rotary_emb(query, image_rotary_emb)
126
+ key = apply_rotary_emb(key, image_rotary_emb)
127
+
128
+ if attention_mask is not None:
129
+ #print ('Attention Used')
130
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
131
+ attention_mask = (attention_mask > 0).bool()
132
+ # Edit 17 - cast the attention mask to the hidden-state device and the query dtype
133
+ attention_mask = attention_mask.to(
134
+ device=hidden_states.device, dtype=query.dtype
135
+ )
136
+
137
+ hidden_states = F.scaled_dot_product_attention(
138
+ query,
139
+ key,
140
+ value,
141
+ dropout_p=0.0,
142
+ is_causal=False,
143
+ attn_mask=attention_mask,
144
+ )
145
+ hidden_states = hidden_states.transpose(1, 2).reshape(
146
+ batch_size, -1, attn.heads * head_dim
147
+ )
148
+ hidden_states = hidden_states.to(query.dtype)
149
+
150
+ if encoder_hidden_states is not None:
151
+ encoder_hidden_states, hidden_states = (
152
+ hidden_states[:, : encoder_hidden_states.shape[1]],
153
+ hidden_states[:, encoder_hidden_states.shape[1] :],
154
+ )
155
+
156
+ # linear proj
157
+ hidden_states = attn.to_out[0](hidden_states)
158
+ # dropout
159
+ hidden_states = attn.to_out[1](hidden_states)
160
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
161
+
162
+ return hidden_states, encoder_hidden_states
163
+ return hidden_states
164
+
165
+
166
+ def expand_flux_attention_mask(
167
+ hidden_states: torch.Tensor,
168
+ attn_mask: torch.Tensor,
169
+ ) -> torch.Tensor:
170
+ """
171
+ Expand a text-token attention mask to the full sequence length so image tokens stay unmasked.
172
+ """
173
+ bsz = attn_mask.shape[0]
174
+ assert bsz == hidden_states.shape[0]
175
+ residual_seq_len = hidden_states.shape[1]
176
+ mask_seq_len = attn_mask.shape[1]
177
+
178
+ expanded_mask = torch.ones(bsz, residual_seq_len)
179
+ expanded_mask[:, :mask_seq_len] = attn_mask
180
+
181
+ return expanded_mask
182
+
183
+
184
+ @maybe_allow_in_graph
185
+ class FluxSingleTransformerBlock(nn.Module):
186
+ r"""
187
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
188
+
189
+ Reference: https://arxiv.org/abs/2403.03206
190
+
191
+ Parameters:
192
+ dim (`int`): The number of channels in the input and output.
193
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
194
+ attention_head_dim (`int`): The number of channels in each head.
195
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
196
+ processing of `context` conditions.
197
+ """
198
+
199
+ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
200
+ super().__init__()
201
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
202
+
203
+ self.norm = AdaLayerNormZeroSingle(dim)
204
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
205
+ self.act_mlp = nn.GELU(approximate="tanh")
206
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
207
+
208
+ processor = FluxAttnProcessor2_0()
209
+ self.attn = Attention(
210
+ query_dim=dim,
211
+ cross_attention_dim=None,
212
+ dim_head=attention_head_dim,
213
+ heads=num_attention_heads,
214
+ out_dim=dim,
215
+ bias=True,
216
+ processor=processor,
217
+ qk_norm="rms_norm",
218
+ eps=1e-6,
219
+ pre_only=True,
220
+ )
221
+
222
+ def forward(
223
+ self,
224
+ hidden_states: torch.FloatTensor,
225
+ temb: torch.FloatTensor,
226
+ image_rotary_emb=None,
227
+ attention_mask: Optional[torch.Tensor] = None,
228
+ joint_attention_kwargs: dict = {},
229
+ ):
230
+ residual = hidden_states
231
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
232
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
233
+
234
+ if attention_mask is not None:
235
+ attention_mask = expand_flux_attention_mask(
236
+ hidden_states,
237
+ attention_mask,
238
+ )
239
+
240
+ # Adding ability to pass hidden state info to IP Adapter
241
+ ip_encoder_hidden_states = None
242
+ ip_layer_scale = 1.0
243
+ if 'ip_hidden_states' in joint_attention_kwargs:
244
+ ip_encoder_hidden_states = joint_attention_kwargs['ip_hidden_states']
245
+ if 'ip_layer_scale' in joint_attention_kwargs:
246
+ ip_layer_scale = joint_attention_kwargs['ip_layer_scale']
247
+
248
+ # Attention.
249
+ attn_output = self.attn(
250
+ hidden_states=norm_hidden_states,
251
+ image_rotary_emb=image_rotary_emb,
252
+ attention_mask=attention_mask,
253
+ ip_encoder_hidden_states=ip_encoder_hidden_states,
254
+ layer_scale=ip_layer_scale,
255
+ )
256
+ #attn_output = self.attn(
257
+ # hidden_states=norm_hidden_states,
258
+ # image_rotary_emb=image_rotary_emb,
259
+ # attention_mask=attention_mask,
260
+ #)
261
+
262
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
263
+ gate = gate.unsqueeze(1)
264
+ hidden_states = gate * self.proj_out(hidden_states)
265
+ hidden_states = residual + hidden_states
266
+
267
+ if hidden_states.dtype == torch.float16:
268
+ hidden_states = hidden_states.clip(-65504, 65504)
269
+
270
+ return hidden_states
271
+
272
+
273
+ @maybe_allow_in_graph
274
+ class FluxTransformerBlock(nn.Module):
275
+ r"""
276
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
277
+
278
+ Reference: https://arxiv.org/abs/2403.03206
279
+
280
+ Parameters:
281
+ dim (`int`): The number of channels in the input and output.
282
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
283
+ attention_head_dim (`int`): The number of channels in each head.
284
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
285
+ processing of `context` conditions.
286
+ """
287
+
288
+ def __init__(
289
+ self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6
290
+ ):
291
+ super().__init__()
292
+
293
+ self.norm1 = AdaLayerNormZero(dim)
294
+
295
+ self.norm1_context = AdaLayerNormZero(dim)
296
+
297
+ if hasattr(F, "scaled_dot_product_attention"):
298
+ processor = FluxAttnProcessor2_0()
299
+ else:
300
+ raise ValueError(
301
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
302
+ )
303
+ self.attn = Attention(
304
+ query_dim=dim,
305
+ cross_attention_dim=None,
306
+ added_kv_proj_dim=dim,
307
+ dim_head=attention_head_dim,
308
+ heads=num_attention_heads,
309
+ out_dim=dim,
310
+ context_pre_only=False,
311
+ bias=True,
312
+ processor=processor,
313
+ qk_norm=qk_norm,
314
+ eps=eps,
315
+ )
316
+
317
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
318
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
319
+
320
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
321
+ self.ff_context = FeedForward(
322
+ dim=dim, dim_out=dim, activation_fn="gelu-approximate"
323
+ )
324
+
325
+ # let chunk size default to None
326
+ self._chunk_size = None
327
+ self._chunk_dim = 0
328
+
329
+ def forward(
330
+ self,
331
+ hidden_states: torch.FloatTensor,
332
+ encoder_hidden_states: torch.FloatTensor,
333
+ temb: torch.FloatTensor,
334
+ image_rotary_emb=None,
335
+ attention_mask: Optional[torch.Tensor] = None,
336
+ joint_attention_kwargs: dict = {},
337
+ ):
338
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
339
+ hidden_states, emb=temb
340
+ )
341
+
342
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
343
+ self.norm1_context(encoder_hidden_states, emb=temb)
344
+ )
345
+
346
+ if attention_mask is not None:
347
+ attention_mask = expand_flux_attention_mask(
348
+ torch.cat([encoder_hidden_states, hidden_states], dim=1),
349
+ attention_mask,
350
+ )
351
+
352
+ # Adding ability to pass hidden state info to IP Adapter
353
+ ip_encoder_hidden_states = None
354
+ ip_layer_scale = 1.0
355
+ if 'ip_hidden_states' in joint_attention_kwargs:
356
+ ip_encoder_hidden_states = joint_attention_kwargs['ip_hidden_states']
357
+ if 'ip_layer_scale' in joint_attention_kwargs:
358
+ ip_layer_scale = joint_attention_kwargs['ip_layer_scale']
359
+
360
+ # Attention.
361
+ attention_outputs = self.attn(
362
+ hidden_states=norm_hidden_states,
363
+ encoder_hidden_states=norm_encoder_hidden_states,
364
+ image_rotary_emb=image_rotary_emb,
365
+ attention_mask=attention_mask,
366
+ ip_encoder_hidden_states=ip_encoder_hidden_states,
367
+ layer_scale=ip_layer_scale,
368
+ )
369
+ if len(attention_outputs) == 2:
370
+ attn_output, context_attn_output = attention_outputs
371
+ elif len(attention_outputs) == 3:
372
+ attn_output, context_attn_output, ip_attn_output = attention_outputs
373
+
374
+ # Process attention outputs for the `hidden_states`.
375
+ attn_output = gate_msa.unsqueeze(1) * attn_output
376
+ hidden_states = hidden_states + attn_output
377
+
378
+ norm_hidden_states = self.norm2(hidden_states)
379
+ norm_hidden_states = (
380
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
381
+ )
382
+
383
+ ff_output = self.ff(norm_hidden_states)
384
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
385
+
386
+ hidden_states = hidden_states + ff_output
387
+
388
+ # Removed: this was causing an error after adding the IP adapter
389
+ #if len(attention_outputs) == 3:
390
+ # hidden_states = hidden_states + ip_attn_output
391
+
392
+ # Process attention outputs for the `encoder_hidden_states`.
393
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
394
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
395
+
396
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
397
+ norm_encoder_hidden_states = (
398
+ norm_encoder_hidden_states * (1 + c_scale_mlp[:, None])
399
+ + c_shift_mlp[:, None]
400
+ )
401
+
402
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
403
+ encoder_hidden_states = (
404
+ encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
405
+ )
406
+
407
+ if encoder_hidden_states.dtype == torch.float16:
408
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
409
+
410
+ return encoder_hidden_states, hidden_states
411
+
412
+ class LibreFluxTransformer2DModel(
413
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
414
+ ):
415
+ """
416
+ The Transformer model introduced in Flux.
417
+
418
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
419
+
420
+ Parameters:
421
+ patch_size (`int`): Patch size to turn the input data into small patches.
422
+ in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
423
+ num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
424
+ num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
425
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
426
+ num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
427
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
428
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
429
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
430
+ """
431
+
432
+ _supports_gradient_checkpointing = True
433
+
434
+ @register_to_config
435
+ def __init__(
436
+ self,
437
+ patch_size: int = 1,
438
+ in_channels: int = 64,
439
+ num_layers: int = 19,
440
+ num_single_layers: int = 38,
441
+ attention_head_dim: int = 128,
442
+ num_attention_heads: int = 24,
443
+ joint_attention_dim: int = 4096,
444
+ pooled_projection_dim: int = 768,
445
+ guidance_embeds: bool = False,
446
+ axes_dims_rope: Tuple[int] = (16, 56, 56),
447
+ ):
448
+ super().__init__()
449
+ self.out_channels = in_channels
450
+ self.inner_dim = (
451
+ self.config.num_attention_heads * self.config.attention_head_dim
452
+ )
453
+
454
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
455
+ text_time_guidance_cls = (
456
+ CombinedTimestepGuidanceTextProjEmbeddings ### 3 input forward (timestep, guidance, pooled_projection)
457
+ if guidance_embeds
458
+ else CombinedTimestepTextProjEmbeddings #### 2 input forward (timestep, pooled_projection)
459
+ )
460
+ self.time_text_embed = text_time_guidance_cls(
461
+ embedding_dim=self.inner_dim,
462
+ pooled_projection_dim=self.config.pooled_projection_dim,
463
+ )
464
+
465
+ self.context_embedder = nn.Linear(
466
+ self.config.joint_attention_dim, self.inner_dim
467
+ )
468
+ self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
469
+
470
+ self.transformer_blocks = nn.ModuleList(
471
+ [
472
+ FluxTransformerBlock(
473
+ dim=self.inner_dim,
474
+ num_attention_heads=self.config.num_attention_heads,
475
+ attention_head_dim=self.config.attention_head_dim,
476
+ )
477
+ for i in range(self.config.num_layers)
478
+ ]
479
+ )
480
+
481
+ self.single_transformer_blocks = nn.ModuleList(
482
+ [
483
+ FluxSingleTransformerBlock(
484
+ dim=self.inner_dim,
485
+ num_attention_heads=self.config.num_attention_heads,
486
+ attention_head_dim=self.config.attention_head_dim,
487
+ )
488
+ for i in range(self.config.num_single_layers)
489
+ ]
490
+ )
491
+
492
+ self.norm_out = AdaLayerNormContinuous(
493
+ self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
494
+ )
495
+ self.proj_out = nn.Linear(
496
+ self.inner_dim, patch_size * patch_size * self.out_channels, bias=True
497
+ )
498
+
499
+ self.gradient_checkpointing = False
500
+ # added for users to disable checkpointing every nth step
501
+ self.gradient_checkpointing_interval = None
502
+
503
+ def set_gradient_checkpointing_interval(self, value: int):
504
+ self.gradient_checkpointing_interval = value
505
+
506
+ @property
507
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
508
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
509
+ r"""
510
+ Returns:
511
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
512
+ indexed by its weight name.
513
+ """
514
+ # set recursively
515
+ processors = {}
516
+
517
+ def fn_recursive_add_processors(
518
+ name: str,
519
+ module: torch.nn.Module,
520
+ processors: Dict[str, AttentionProcessor],
521
+ ):
522
+ if hasattr(module, "get_processor"):
523
+ processors[f"{name}.processor"] = module.get_processor()
524
+
525
+ for sub_name, child in module.named_children():
526
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
527
+
528
+ return processors
529
+
530
+ for name, module in self.named_children():
531
+ fn_recursive_add_processors(name, module, processors)
532
+
533
+ return processors
534
+
535
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
536
+ def set_attn_processor(
537
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
538
+ ):
539
+ r"""
540
+ Sets the attention processor to use to compute attention.
541
+
542
+ Parameters:
543
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
544
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
545
+ for **all** `Attention` layers.
546
+
547
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
548
+ processor. This is strongly recommended when setting trainable attention processors.
549
+
550
+ """
551
+ count = len(self.attn_processors.keys())
552
+
553
+ if isinstance(processor, dict) and len(processor) != count:
554
+ raise ValueError(
555
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
556
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
557
+ )
558
+
559
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
560
+ if hasattr(module, "set_processor"):
561
+ if not isinstance(processor, dict):
562
+ module.set_processor(processor)
563
+ else:
564
+ module.set_processor(processor.pop(f"{name}.processor"))
565
+
566
+ for sub_name, child in module.named_children():
567
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
568
+
569
+ for name, module in self.named_children():
570
+ fn_recursive_attn_processor(name, module, processor)
571
+
572
+ def forward(
573
+ self,
574
+ hidden_states: torch.Tensor,
575
+ encoder_hidden_states: torch.Tensor = None,
576
+ pooled_projections: torch.Tensor = None,
577
+ timestep: torch.LongTensor = None,
578
+ img_ids: torch.Tensor = None,
579
+ txt_ids: torch.Tensor = None,
580
+ guidance: torch.Tensor = None,
581
+ joint_attention_kwargs: dict = {},
582
+ controlnet_block_samples=None,
583
+ controlnet_single_block_samples=None,
584
+ return_dict: bool = True,
585
+ attention_mask: Optional[torch.Tensor] = None,
586
+ controlnet_blocks_repeat: bool = False,
587
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
588
+ """
589
+ The [`FluxTransformer2DModel`] forward method.
590
+
591
+ Args:
592
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
593
+ Input `hidden_states`.
594
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
595
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
596
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
597
+ from the embeddings of input conditions.
598
+ timestep ( `torch.LongTensor`):
599
+ Used to indicate denoising step.
600
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
601
+ A list of tensors that if specified are added to the residuals of transformer blocks.
602
+ joint_attention_kwargs (`dict`, *optional*):
603
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
604
+ `self.processor` in
605
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
606
+ return_dict (`bool`, *optional*, defaults to `True`):
607
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
608
+ tuple.
609
+
610
+ Returns:
611
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
612
+ `tuple` where the first element is the sample tensor.
613
+ """
614
+ if joint_attention_kwargs is not None:
615
+ joint_attention_kwargs = joint_attention_kwargs.copy()
616
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
617
+ else:
618
+ lora_scale = 1.0
619
+
620
+ if USE_PEFT_BACKEND:
621
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
622
+ scale_lora_layers(self, lora_scale)
623
+ else:
624
+ if (
625
+ joint_attention_kwargs is not None
626
+ and joint_attention_kwargs.get("scale", None) is not None
627
+ ):
628
+ logger.warning(
629
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
630
+ )
631
+ hidden_states = self.x_embedder(hidden_states)
632
+
633
+ timestep = timestep.to(hidden_states.dtype) * 1000
634
+ if guidance is not None:
635
+ guidance = guidance.to(hidden_states.dtype) * 1000
636
+ else:
637
+ guidance = None
638
+
639
+ #print( self.time_text_embed)
640
+ temb = (
641
+ self.time_text_embed(timestep,pooled_projections)
642
+ # Edit 1 # Charlie NOT NEEDED - UNDONE
643
+ if guidance is None
644
+ else self.time_text_embed(timestep, guidance, pooled_projections)
645
+ )
646
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
647
+
648
+ if txt_ids.ndim == 3:
649
+ txt_ids = txt_ids[0]
650
+ if img_ids.ndim == 3:
651
+ img_ids = img_ids[0]
652
+
653
+ ids = torch.cat((txt_ids, img_ids), dim=0)
654
+
655
+ image_rotary_emb = self.pos_embed(ids)
656
+
657
+
658
+ for index_block, block in enumerate(self.transformer_blocks):
659
+ if (
660
+ self.training
661
+ and self.gradient_checkpointing
662
+ and (
663
+ self.gradient_checkpointing_interval is None
664
+ or index_block % self.gradient_checkpointing_interval == 0
665
+ )
666
+ ):
667
+
668
+ def create_custom_forward(module, return_dict=None):
669
+ def custom_forward(*inputs):
670
+ if return_dict is not None:
671
+ return module(*inputs, return_dict=return_dict)
672
+ else:
673
+ return module(*inputs)
674
+
675
+ return custom_forward
676
+
677
+ ckpt_kwargs: Dict[str, Any] = (
678
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
679
+ )
680
+ encoder_hidden_states, hidden_states = (
681
+ torch.utils.checkpoint.checkpoint(
682
+ create_custom_forward(block),
683
+ hidden_states,
684
+ encoder_hidden_states,
685
+ temb,
686
+ image_rotary_emb,
687
+ attention_mask,
688
+ joint_attention_kwargs, # Add this line
689
+ **ckpt_kwargs,
690
+ )
691
+ )
692
+
693
+ else:
694
+ encoder_hidden_states, hidden_states = block(
695
+ hidden_states=hidden_states,
696
+ encoder_hidden_states=encoder_hidden_states,
697
+ temb=temb,
698
+ image_rotary_emb=image_rotary_emb,
699
+ attention_mask=attention_mask,
700
+ joint_attention_kwargs=joint_attention_kwargs # Add this line
701
+ )
702
+
703
+ # controlnet residual
704
+ if controlnet_block_samples is not None:
705
+ interval_control = len(self.transformer_blocks) / len(
706
+ controlnet_block_samples
707
+ )
708
+ interval_control = int(np.ceil(interval_control))
709
+ # For Xlabs ControlNet.
710
+ if controlnet_blocks_repeat:
711
+ hidden_states = (
712
+ hidden_states
713
+ + controlnet_block_samples[
714
+ index_block % len(controlnet_block_samples)
715
+ ]
716
+ )
717
+ else:
718
+ hidden_states = (
719
+ hidden_states
720
+ + controlnet_block_samples[index_block // interval_control]
721
+ )
722
+
723
+ # Flux places the text tokens in front of the image tokens in the
724
+ # sequence.
725
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
726
+
727
+ for index_block, block in enumerate(self.single_transformer_blocks):
728
+ if (
729
+ self.training
730
+ and self.gradient_checkpointing
731
+ and (
732
+ self.gradient_checkpointing_interval is None
733
+ or index_block % self.gradient_checkpointing_interval == 0
734
+ )
735
+ ):
736
+
737
+ def create_custom_forward(module, return_dict=None):
738
+ def custom_forward(*inputs):
739
+ if return_dict is not None:
740
+ return module(*inputs, return_dict=return_dict)
741
+ else:
742
+ return module(*inputs)
743
+
744
+ return custom_forward
745
+
746
+ ckpt_kwargs: Dict[str, Any] = (
747
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
748
+ )
749
+ hidden_states = torch.utils.checkpoint.checkpoint(
750
+ create_custom_forward(block),
751
+ hidden_states,
752
+ temb,
753
+ image_rotary_emb,
754
+ attention_mask,
755
+ joint_attention_kwargs, # Add this line
756
+
757
+ **ckpt_kwargs,
758
+ )
759
+
760
+ else:
761
+ hidden_states = block(
762
+ hidden_states=hidden_states,
763
+ temb=temb,
764
+ image_rotary_emb=image_rotary_emb,
765
+ attention_mask=attention_mask,
766
+ joint_attention_kwargs= joint_attention_kwargs, # Add this line
767
+
768
+ )
769
+
770
+ # controlnet residual
771
+ if controlnet_single_block_samples is not None:
772
+ interval_control = len(self.single_transformer_blocks) / len(
773
+ controlnet_single_block_samples
774
+ )
775
+ interval_control = int(np.ceil(interval_control))
776
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
777
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
778
+ + controlnet_single_block_samples[index_block // interval_control]
779
+ )
780
+
781
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
782
+
783
+ hidden_states = self.norm_out(hidden_states, temb)
784
+ output = self.proj_out(hidden_states)
785
+
786
+ if USE_PEFT_BACKEND:
787
+ # remove `lora_scale` from each PEFT layer
788
+ unscale_lora_layers(self, lora_scale)
789
+
790
+ if not return_dict:
791
+ return (output,)
792
+
793
+ return Transformer2DModelOutput(sample=output)
794
+
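Note: the forward method above threads joint_attention_kwargs into every double- and single-stream block, which is how the pipeline hands IP-Adapter state (ip_hidden_states, ip_layer_scale) to FluxAttnProcessor2_0 without changing the block signatures. A hedged sketch of a bare forward call, reusing the transformer loaded in the earlier sketch; all shapes are illustrative placeholders and ip_embeds merely stands in for features produced by flux_ip_adapter.py:

import torch

# Illustrative sizes only: 1024 packed latent tokens (64 channels) and 512 text tokens.
hidden_states = torch.randn(1, 1024, 64, dtype=torch.bfloat16)
encoder_hidden_states = torch.randn(1, 512, 4096, dtype=torch.bfloat16)
pooled_projections = torch.randn(1, 768, dtype=torch.bfloat16)
timestep = torch.tensor([1.0])
img_ids = torch.zeros(1024, 3)
txt_ids = torch.zeros(512, 3)
ip_embeds = torch.randn(1, 64, 3072, dtype=torch.bfloat16)  # hypothetical shape for IP-Adapter features

with torch.no_grad():
    sample = transformer(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        pooled_projections=pooled_projections,
        timestep=timestep,
        img_ids=img_ids,
        txt_ids=txt_ids,
        joint_attention_kwargs={
            "ip_hidden_states": ip_embeds,  # picked up by each block and passed to the attention processor
            "ip_layer_scale": 0.8,          # per-layer strength; defaults to 1.0 when omitted
        },
        return_dict=False,
    )[0]
print(sample.shape)  # torch.Size([1, 1024, 64])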
vae/config.json ADDED
@@ -0,0 +1,37 @@
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.30.0.dev0",
4
+ "act_fn": "silu",
5
+ "block_out_channels": [
6
+ 128,
7
+ 256,
8
+ 512,
9
+ 512
10
+ ],
11
+ "down_block_types": [
12
+ "DownEncoderBlock2D",
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D"
16
+ ],
17
+ "force_upcast": true,
18
+ "in_channels": 3,
19
+ "latent_channels": 16,
20
+ "latents_mean": null,
21
+ "latents_std": null,
22
+ "layers_per_block": 2,
23
+ "mid_block_add_attention": true,
24
+ "norm_num_groups": 32,
25
+ "out_channels": 3,
26
+ "sample_size": 1024,
27
+ "scaling_factor": 0.3611,
28
+ "shift_factor": 0.1159,
29
+ "up_block_types": [
30
+ "UpDecoderBlock2D",
31
+ "UpDecoderBlock2D",
32
+ "UpDecoderBlock2D",
33
+ "UpDecoderBlock2D"
34
+ ],
35
+ "use_post_quant_conv": false,
36
+ "use_quant_conv": false
37
+ }
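Note: this is the Flux-style 16-latent-channel AutoencoderKL; scaling_factor and shift_factor are applied to latents on the way into and out of the transformer. A minimal loading sketch with a placeholder repo id:

import torch
from diffusers import AutoencoderKL

# Placeholder repo id; substitute the actual location of this upload.
vae = AutoencoderKL.from_pretrained(
    "your-namespace/LibreFlux-IP-Adapter",
    subfolder="vae",
    torch_dtype=torch.bfloat16,
)

# Encode an image and normalize latents the way Flux pipelines expect.
image = torch.randn(1, 3, 1024, 1024, dtype=torch.bfloat16)
latents = vae.encode(image).latent_dist.sample()
latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor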
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f5b59a26851551b67ae1fe58d32e76486e1e812def4696a4bea97f16604d40a3
3
+ size 167666902