vidfom commited on
Commit
dc959fc
·
verified ·
1 Parent(s): 8f104e3

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. content/flux/.gitattributes +36 -0
  3. content/flux/folder_paths.py +270 -0
  4. content/flux/latent_preview.py +94 -0
  5. content/flux/models/clip/clip_l.safetensors +3 -0
  6. content/flux/models/clip/t5xxl_fp8_e4m3fn.safetensors +3 -0
  7. content/flux/models/unet/flux1-schnell.safetensors +3 -0
  8. content/flux/models/vae/ae.sft +3 -0
  9. content/flux/node_helpers.py +37 -0
  10. content/flux/nodes.py +2073 -0
  11. content/flux/totoro/__pycache__/cli_args.cpython-311.pyc +0 -0
  12. content/flux/totoro/__pycache__/diffusers_load.cpython-311.pyc +0 -0
  13. content/flux/totoro/__pycache__/model_management.cpython-311.pyc +0 -0
  14. content/flux/totoro/__pycache__/options.cpython-311.pyc +0 -0
  15. content/flux/totoro/__pycache__/sd.cpython-311.pyc +0 -0
  16. content/flux/totoro/checkpoint_pickle.py +13 -0
  17. content/flux/totoro/cldm/cldm.py +437 -0
  18. content/flux/totoro/cldm/control_types.py +10 -0
  19. content/flux/totoro/cldm/mmdit.py +77 -0
  20. content/flux/totoro/cli_args.py +180 -0
  21. content/flux/totoro/clip_config_bigg.json +23 -0
  22. content/flux/totoro/clip_model.py +196 -0
  23. content/flux/totoro/clip_vision.py +121 -0
  24. content/flux/totoro/clip_vision_config_g.json +18 -0
  25. content/flux/totoro/clip_vision_config_h.json +18 -0
  26. content/flux/totoro/clip_vision_config_vitl.json +18 -0
  27. content/flux/totoro/clip_vision_config_vitl_336.json +18 -0
  28. content/flux/totoro/conds.py +83 -0
  29. content/flux/totoro/controlnet.py +610 -0
  30. content/flux/totoro/diffusers_convert.py +281 -0
  31. content/flux/totoro/diffusers_load.py +36 -0
  32. content/flux/totoro/extra_samplers/uni_pc.py +875 -0
  33. content/flux/totoro/gligen.py +343 -0
  34. content/flux/totoro/k_diffusion/deis.py +121 -0
  35. content/flux/totoro/k_diffusion/sampling.py +1049 -0
  36. content/flux/totoro/k_diffusion/utils.py +313 -0
  37. content/flux/totoro/latent_formats.py +152 -0
  38. content/flux/totoro/ldm/audio/autoencoder.py +282 -0
  39. content/flux/totoro/ldm/audio/dit.py +891 -0
  40. content/flux/totoro/ldm/audio/embedders.py +108 -0
  41. content/flux/totoro/ldm/aura/mmdit.py +480 -0
  42. content/flux/totoro/ldm/cascade/common.py +154 -0
  43. content/flux/totoro/ldm/cascade/controlnet.py +93 -0
  44. content/flux/totoro/ldm/cascade/stage_a.py +255 -0
  45. content/flux/totoro/ldm/cascade/stage_b.py +256 -0
  46. content/flux/totoro/ldm/cascade/stage_c.py +273 -0
  47. content/flux/totoro/ldm/cascade/stage_c_coder.py +95 -0
  48. content/flux/totoro/ldm/flux/layers.py +256 -0
  49. content/flux/totoro/ldm/flux/math.py +35 -0
  50. content/flux/totoro/ldm/flux/model.py +138 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ content/flux/models/vae/ae.sft filter=lfs diff=lfs merge=lfs -text
content/flux/.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ models/vae/ae.sft filter=lfs diff=lfs merge=lfs -text
content/flux/folder_paths.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import logging
4
+ from typing import Set, List, Dict, Tuple
5
+
6
# File extensions treated as loadable model weights.
supported_pt_extensions: Set[str] = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'])

SupportedFileExtensionsType = Set[str]
ScanPathType = List[str]
# Maps a logical model category -> (list of directories to scan, allowed extensions).
folder_names_and_paths: Dict[str, Tuple[ScanPathType, SupportedFileExtensionsType]] = {}

base_path = os.path.dirname(os.path.realpath(__file__))
models_dir = os.path.join(base_path, "models")
folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_pt_extensions)
folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".yaml"])

folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)
folder_names_and_paths["unet"] = ([os.path.join(models_dir, "unet")], supported_pt_extensions)
folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
# "folder" is a sentinel: diffusers models are whole directories, not files.
folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"])
folder_names_and_paths["vae_approx"] = ([os.path.join(models_dir, "vae_approx")], supported_pt_extensions)

# controlnet deliberately searches two directories (t2i adapters share the category).
folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], supported_pt_extensions)

folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions)

# Empty set: custom_nodes entries are not filtered by extension here.
folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], set())

folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions)

folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")], supported_pt_extensions)

# {""} matches extension-less files only.
folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})

output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user")

# folder_name -> (sorted filenames, {dir: mtime}, scan timestamp); see get_filename_list.
filename_list_cache = {}

if not os.path.exists(input_directory):
    try:
        os.makedirs(input_directory)
    except OSError:  # was a bare except; best-effort, only filesystem errors expected
        logging.error("Failed to create input directory")
52
+
53
def set_output_directory(output_dir):
    """Point the global output directory at *output_dir*."""
    global output_directory
    output_directory = output_dir

def set_temp_directory(temp_dir):
    """Point the global temp directory at *temp_dir*."""
    global temp_directory
    temp_directory = temp_dir

def set_input_directory(input_dir):
    """Point the global input directory at *input_dir*."""
    global input_directory
    input_directory = input_dir

def get_output_directory():
    """Return the current output directory."""
    return output_directory

def get_temp_directory():
    """Return the current temp directory."""
    return temp_directory

def get_input_directory():
    """Return the current input directory."""
    return input_directory


#NOTE: used in http server so don't put folders that should not be accessed remotely
def get_directory_by_type(type_name):
    """Map "output"/"temp"/"input" to its directory; None for anything else."""
    getters = {
        "output": get_output_directory,
        "temp": get_temp_directory,
        "input": get_input_directory,
    }
    getter = getters.get(type_name)
    return None if getter is None else getter()
87
+
88
+
89
# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
# otherwise use default_path as base_dir
def annotated_filepath(name):
    """Split "file.ext [output|input|temp]" into (bare name, base dir).

    Returns (name, None) when no annotation suffix is present. The slice
    also drops the space that conventionally precedes the annotation.
    """
    if name.endswith("[output]"):
        return name[:-9], get_output_directory()
    if name.endswith("[input]"):
        return name[:-8], get_input_directory()
    if name.endswith("[temp]"):
        return name[:-7], get_temp_directory()
    return name, None


def get_annotated_filepath(name, default_dir=None):
    """Full path for an annotated name.

    Unannotated names resolve against *default_dir*, falling back to the
    input directory when none is given.
    """
    name, base_dir = annotated_filepath(name)
    if base_dir is None:
        base_dir = default_dir if default_dir is not None else get_input_directory()
    return os.path.join(base_dir, name)


def exists_annotated_filepath(name):
    """True if the file an annotated name refers to exists on disk."""
    name, base_dir = annotated_filepath(name)
    if base_dir is None:
        base_dir = get_input_directory()  # fallback path
    return os.path.exists(os.path.join(base_dir, name))
127
+
128
+
129
def add_model_folder_path(folder_name, full_folder_path):
    """Register an extra search path for *folder_name*.

    Creates the category (with an empty extension set) when it does not
    exist yet.
    """
    global folder_names_and_paths
    entry = folder_names_and_paths.get(folder_name)
    if entry is not None:
        entry[0].append(full_folder_path)
    else:
        folder_names_and_paths[folder_name] = ([full_folder_path], set())

def get_folder_paths(folder_name):
    """Return a copy of the search-path list for *folder_name*."""
    paths, _extensions = folder_names_and_paths[folder_name]
    return list(paths)
138
+
139
def recursive_search(directory, excluded_dir_names=None):
    """Walk *directory* (following symlinks) and collect relative file paths.

    Returns (files, {dir_path: mtime}); the mtime map is used by the
    filename-list cache to detect stale scans. Directories named in
    *excluded_dir_names* are pruned from the walk.
    """
    if not os.path.isdir(directory):
        return [], {}

    excluded = set(excluded_dir_names) if excluded_dir_names is not None else set()

    found_files = []
    dir_mtimes = {}

    # Record the root directory's mtime up front; tolerate it vanishing.
    try:
        dir_mtimes[directory] = os.path.getmtime(directory)
    except FileNotFoundError:
        logging.warning(f"Warning: Unable to access {directory}. Skipping this path.")

    logging.debug("recursive file list on directory {}".format(directory))
    for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
        # Prune excluded directories in place so os.walk never descends into them.
        subdirs[:] = [d for d in subdirs if d not in excluded]
        found_files.extend(
            os.path.relpath(os.path.join(dirpath, file_name), directory)
            for file_name in filenames
        )

        for sub in subdirs:
            sub_path = os.path.join(dirpath, sub)
            try:
                dir_mtimes[sub_path] = os.path.getmtime(sub_path)
            except FileNotFoundError:
                logging.warning(f"Warning: Unable to access {sub_path}. Skipping this path.")
    logging.debug("found {} files".format(len(found_files)))
    return found_files, dir_mtimes

def filter_files_extensions(files, extensions):
    """Sorted subset of *files* whose extension (case-insensitive) is in
    *extensions*; an empty extension collection keeps everything."""
    kept = [
        f for f in files
        if len(extensions) == 0 or os.path.splitext(f)[-1].lower() in extensions
    ]
    return sorted(kept)
174
+
175
+
176
+
177
def get_full_path(folder_name, filename):
    """Locate *filename* inside the search paths of *folder_name*.

    The filename is normalized relative to "/" so "../" segments cannot
    escape the configured directories. Returns the first existing file,
    or None; broken symlinks are logged and skipped.
    """
    global folder_names_and_paths
    entry = folder_names_and_paths.get(folder_name)
    if entry is None:
        return None
    safe_name = os.path.relpath(os.path.join("/", filename), "/")
    for base in entry[0]:
        candidate = os.path.join(base, safe_name)
        if os.path.isfile(candidate):
            return candidate
        if os.path.islink(candidate):
            logging.warning("WARNING path {} exists but doesn't link anywhere, skipping.".format(candidate))

    return None
191
+
192
def get_filename_list_(folder_name):
    """Scan every configured folder for *folder_name*.

    Returns (sorted unique relative filenames, {dir_path: mtime},
    scan timestamp). The mtime map and timestamp are consumed by
    cached_filename_list_ to decide when the cached scan is stale.
    """
    global folder_names_and_paths
    output_list = set()
    folders = folder_names_and_paths[folder_name]
    output_folders = {}
    for x in folders[0]:
        # folders[1] is the allowed-extension collection; .git is never scanned
        files, folders_all = recursive_search(x, excluded_dir_names=[".git"])
        output_list.update(filter_files_extensions(files, folders[1]))
        output_folders = {**output_folders, **folders_all}

    return (sorted(list(output_list)), output_folders, time.perf_counter())
203
+
204
def cached_filename_list_(folder_name):
    """Return the cached scan for *folder_name*, or None when it is stale.

    A cache entry is invalidated when any directory seen during the scan
    changed its mtime, or when a configured folder now exists on disk but
    was not part of the cached scan.
    """
    global filename_list_cache
    global folder_names_and_paths
    if folder_name not in filename_list_cache:
        return None
    out = filename_list_cache[folder_name]

    # out[1] maps directory path -> mtime recorded at scan time.
    # NOTE(review): os.path.getmtime will raise if a scanned directory was
    # deleted since the scan — presumably acceptable upstream; confirm.
    for x in out[1]:
        time_modified = out[1][x]
        folder = x
        if os.path.getmtime(folder) != time_modified:
            return None

    # A configured folder that appeared after the scan also invalidates.
    folders = folder_names_and_paths[folder_name]
    for x in folders[0]:
        if os.path.isdir(x):
            if x not in out[1]:
                return None

    return out
224
+
225
def get_filename_list(folder_name):
    """Cached list of filenames for *folder_name*, rescanning when stale."""
    global filename_list_cache
    cached = cached_filename_list_(folder_name)
    if cached is None:
        cached = get_filename_list_(folder_name)
        filename_list_cache[folder_name] = cached
    return list(cached[0])
232
+
233
def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0):
    """Resolve where the next image named after *filename_prefix* goes.

    The prefix may contain a subfolder ("sub/name") and the placeholders
    %width%/%height%. Returns (full_output_folder, filename, counter,
    subfolder, filename_prefix) where counter is 1 + the highest existing
    "<filename>_<number>" index in the target folder (1 when empty or the
    folder had to be created).

    Raises Exception if the resolved folder would escape *output_dir*.
    """
    def map_filename(filename):
        # Split "name_00012_.png" into (12, "name_") so max() finds the
        # highest existing counter for this prefix.
        prefix_len = len(os.path.basename(filename_prefix))
        prefix = filename[:prefix_len + 1]
        try:
            digits = int(filename[prefix_len + 1:].split('_')[0])
        except ValueError:  # was a bare except; int() is the only thing that can fail here
            digits = 0
        return (digits, prefix)

    def compute_vars(text, image_width, image_height):
        # Substitute the size placeholders users may embed in the prefix.
        # (parameter renamed from `input`, which shadowed the builtin)
        text = text.replace("%width%", str(image_width))
        text = text.replace("%height%", str(image_height))
        return text

    filename_prefix = compute_vars(filename_prefix, image_width, image_height)

    subfolder = os.path.dirname(os.path.normpath(filename_prefix))
    filename = os.path.basename(os.path.normpath(filename_prefix))

    full_output_folder = os.path.join(output_dir, subfolder)

    # Refuse prefixes like "../../x" that resolve outside output_dir.
    if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir:
        err = "**** ERROR: Saving image outside the output folder is not allowed." + \
              "\n full_output_folder: " + os.path.abspath(full_output_folder) + \
              "\n output_dir: " + output_dir + \
              "\n commonpath: " + os.path.commonpath((output_dir, os.path.abspath(full_output_folder)))
        logging.error(err)
        raise Exception(err)

    try:
        counter = max(filter(lambda a: os.path.normcase(a[1][:-1]) == os.path.normcase(filename) and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
    except ValueError:
        # No files matching "<filename>_<number>" yet.
        counter = 1
    except FileNotFoundError:
        os.makedirs(full_output_folder, exist_ok=True)
        counter = 1
    return full_output_folder, filename, counter, subfolder, filename_prefix
content/flux/latent_preview.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ import struct
4
+ import numpy as np
5
+ from totoro.cli_args import args, LatentPreviewMethod
6
+ from totoro.taesd.taesd import TAESD
7
+ import totoro.model_management
8
+ import folder_paths
9
+ import totoro.utils
10
+ import logging
11
+
12
# Longest edge allowed for preview images sent to the UI.
MAX_PREVIEW_RESOLUTION = 512

def preview_to_image(latent_image):
    """Convert a decoded [-1, 1] image tensor to a CPU PIL image."""
    non_blocking = totoro.model_management.device_supports_non_blocking(latent_image.device)
    # Rescale -1..1 -> 0..255 and quantize to bytes on the CPU.
    scaled = ((latent_image + 1.0) / 2.0).clamp(0, 1).mul(0xFF)
    latents_ubyte = scaled.to(device="cpu", dtype=torch.uint8, non_blocking=non_blocking)
    return Image.fromarray(latents_ubyte.numpy())

class LatentPreviewer:
    """Base class: subclasses implement decode_latent_to_preview."""

    def decode_latent_to_preview(self, x0):
        pass

    def decode_latent_to_preview_image(self, preview_format, x0):
        # Format is currently fixed to JPEG regardless of the argument.
        return ("JPEG", self.decode_latent_to_preview(x0), MAX_PREVIEW_RESOLUTION)
28
+
29
class TAESDPreviewerImpl(LatentPreviewer):
    """Decode latents to preview images with a TAESD decoder."""

    def __init__(self, taesd):
        self.taesd = taesd

    def decode_latent_to_preview(self, x0):
        # Only the first latent in the batch is previewed.
        decoded = self.taesd.decode(x0[:1])[0]
        return preview_to_image(decoded.movedim(0, 2))
36
+
37
+
38
class Latent2RGBPreviewer(LatentPreviewer):
    """Cheap preview: project latent channels to RGB with a fixed matrix."""

    def __init__(self, latent_rgb_factors):
        self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu")

    def decode_latent_to_preview(self, x0):
        # Lazily move the projection matrix to the latent's device/dtype.
        self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
        channels_last = x0[0].permute(1, 2, 0)
        return preview_to_image(channels_last @ self.latent_rgb_factors)
46
+
47
+
48
def get_previewer(device, latent_format):
    """Build a latent previewer for *latent_format*, or None when previews
    are disabled or no usable decoder exists.

    Preference order: TAESD (when selected and its decoder weights are
    found under models/vae_approx), then the latent->RGB projection.
    """
    previewer = None
    method = args.preview_method
    if method != LatentPreviewMethod.NoPreviews:
        # TODO previewer methods
        taesd_decoder_path = None
        if latent_format.taesd_decoder_name is not None:
            # First vae_approx file whose name starts with the decoder name;
            # "" when none match (get_full_path then returns None).
            taesd_decoder_path = next(
                (fn for fn in folder_paths.get_filename_list("vae_approx")
                 if fn.startswith(latent_format.taesd_decoder_name)),
                ""
            )
            taesd_decoder_path = folder_paths.get_full_path("vae_approx", taesd_decoder_path)

        # Auto currently always resolves to the cheap RGB projection.
        if method == LatentPreviewMethod.Auto:
            method = LatentPreviewMethod.Latent2RGB

        if method == LatentPreviewMethod.TAESD:
            if taesd_decoder_path:
                taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device)
                previewer = TAESDPreviewerImpl(taesd)
            else:
                logging.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))

        # Fall back to latent->RGB when TAESD was unavailable or not chosen.
        if previewer is None:
            if latent_format.latent_rgb_factors is not None:
                previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors)
    return previewer
76
+
77
def prepare_callback(model, steps, x0_output_dict=None):
    """Return a sampler step callback that advances the progress bar and,
    when a previewer is available, attaches a decoded preview.

    If *x0_output_dict* is given, the latest x0 prediction is stored under
    its "x0" key each step so callers can inspect it afterwards.
    """
    preview_format = "JPEG"
    if preview_format not in ["JPEG", "PNG"]:
        preview_format = "JPEG"

    previewer = get_previewer(model.load_device, model.model.latent_format)

    pbar = totoro.utils.ProgressBar(steps)
    def callback(step, x0, x, total_steps):
        if x0_output_dict is not None:
            x0_output_dict["x0"] = x0

        preview_bytes = None
        if previewer:
            # A ("JPEG", image, max-resolution) tuple, not raw bytes.
            preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
        pbar.update_absolute(step + 1, total_steps, preview_bytes)
    return callback
94
+
content/flux/models/clip/clip_l.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:660c6f5b1abae9dc498ac2d21e1347d2abdb0cf6c0c0c8576cd796491d9a6cdd
3
+ size 246144152
content/flux/models/clip/t5xxl_fp8_e4m3fn.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7d330da4816157540d6bb7838bf63a0f02f573fc48ca4d8de34bb0cbfd514f09
3
+ size 4893934904
content/flux/models/unet/flux1-schnell.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9403429e0052277ac2a87ad800adece5481eecefd9ed334e1f348723621d2a0a
3
+ size 23782506688
content/flux/models/vae/ae.sft ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:afc8e28272cd15db3919bacdb6918ce9c1ed22e96cb12c4d5ed0fba823529e38
3
+ size 335304388
content/flux/node_helpers.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+
3
+ from totoro.cli_args import args
4
+
5
+ from PIL import ImageFile, UnidentifiedImageError
6
+
7
def conditioning_set_values(conditioning, values=None):
    """Return a copy of *conditioning* with *values* merged into each
    entry's options dict.

    Each entry is a [cond, options] pair; cond is shared, options is
    shallow-copied so callers' dicts are never mutated. The default was a
    mutable ``{}`` — a shared-default bug magnet — now None.
    """
    if values is None:
        values = {}
    updated = []
    for entry in conditioning:
        options = entry[1].copy()
        options.update(values)
        updated.append([entry[0], options])

    return updated
16
+
17
def pillow(fn, arg):
    """Call *fn(arg)*, retrying once with PIL's LOAD_TRUNCATED_IMAGES
    enabled when PIL rejects the input.

    Works around PIL issues #4472 and #2445; also fixes totoroUI issue
    #3416. The flag is always restored after the retry.
    """
    try:
        return fn(arg)
    except (OSError, UnidentifiedImageError, ValueError):
        saved = ImageFile.LOAD_TRUNCATED_IMAGES
        ImageFile.LOAD_TRUNCATED_IMAGES = True
        try:
            return fn(arg)
        finally:
            ImageFile.LOAD_TRUNCATED_IMAGES = saved
29
+
30
def hasher():
    """Hash constructor selected by the default_hashing_function argument."""
    return {
        "md5": hashlib.md5,
        "sha1": hashlib.sha1,
        "sha256": hashlib.sha256,
        "sha512": hashlib.sha512,
    }[args.default_hashing_function]
content/flux/nodes.py ADDED
@@ -0,0 +1,2073 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ import os
4
+ import sys
5
+ import json
6
+ import hashlib
7
+ import traceback
8
+ import math
9
+ import time
10
+ import random
11
+ import logging
12
+
13
+ from PIL import Image, ImageOps, ImageSequence, ImageFile
14
+ from PIL.PngImagePlugin import PngInfo
15
+
16
+ import numpy as np
17
+ import safetensors.torch
18
+
19
+ sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "totoro"))
20
+
21
+ import totoro.diffusers_load
22
+ import totoro.samplers
23
+ import totoro.sample
24
+ import totoro.sd
25
+ import totoro.utils
26
+ import totoro.controlnet
27
+
28
+ import totoro.clip_vision
29
+
30
+ import totoro.model_management
31
+ from totoro.cli_args import args
32
+
33
+ import importlib
34
+
35
+ import folder_paths
36
+ import latent_preview
37
+ import node_helpers
38
+
39
def before_node_execution():
    """Raise (abort the graph) if the user interrupted processing."""
    totoro.model_management.throw_exception_if_processing_interrupted()

def interrupt_processing(value=True):
    """Set or clear the global processing-interrupt flag."""
    totoro.model_management.interrupt_current_processing(value)

# Upper bound for width/height-style INT inputs on the nodes below.
MAX_RESOLUTION=16384
46
+
47
class CLIPTextEncode:
    """Encode a text prompt with a CLIP model into conditioning."""

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"
    CATEGORY = "conditioning"

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", )}}

    def encode(self, clip, text):
        tokens = clip.tokenize(text)
        encoded = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
        # "cond" becomes the tensor; everything else (pooled output etc.)
        # rides along as the options dict.
        cond = encoded.pop("cond")
        return ([[cond, encoded]], )
61
+
62
class ConditioningCombine:
    """Merge two conditioning lists into one (simple concatenation)."""

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "combine"
    CATEGORY = "conditioning"

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning_1": ("CONDITIONING", ), "conditioning_2": ("CONDITIONING", )}}

    def combine(self, conditioning_1, conditioning_2):
        combined = list(conditioning_1) + list(conditioning_2)
        return (combined, )
73
+
74
class ConditioningAverage :
    """Blend conditioning_from into conditioning_to by linear interpolation.

    conditioning_to_strength = 1.0 keeps conditioning_to unchanged;
    0.0 replaces it with (the first entry of) conditioning_from.
    """
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning_to": ("CONDITIONING", ), "conditioning_from": ("CONDITIONING", ),
                             "conditioning_to_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
                             }}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "addWeighted"

    CATEGORY = "conditioning"

    def addWeighted(self, conditioning_to, conditioning_from, conditioning_to_strength):
        out = []

        if len(conditioning_from) > 1:
            logging.warning("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")

        # Only the first "from" entry participates in the blend.
        cond_from = conditioning_from[0][0]
        pooled_output_from = conditioning_from[0][1].get("pooled_output", None)

        for i in range(len(conditioning_to)):
            t1 = conditioning_to[i][0]
            pooled_output_to = conditioning_to[i][1].get("pooled_output", pooled_output_from)
            # Truncate cond_from to t1's token length; zero-pad when shorter.
            # NOTE(review): the pad assumes batch size 1 — confirm with callers.
            t0 = cond_from[:,:t1.shape[1]]
            if t0.shape[1] < t1.shape[1]:
                t0 = torch.cat([t0] + [torch.zeros((1, (t1.shape[1] - t0.shape[1]), t1.shape[2]))], dim=1)

            # Lerp: strength weights the "to" side.
            tw = torch.mul(t1, conditioning_to_strength) + torch.mul(t0, (1.0 - conditioning_to_strength))
            t_to = conditioning_to[i][1].copy()
            if pooled_output_from is not None and pooled_output_to is not None:
                t_to["pooled_output"] = torch.mul(pooled_output_to, conditioning_to_strength) + torch.mul(pooled_output_from, (1.0 - conditioning_to_strength))
            elif pooled_output_from is not None:
                t_to["pooled_output"] = pooled_output_from

            n = [tw, t_to]
            out.append(n)
        return (out, )
111
+
112
class ConditioningConcat:
    """Concatenate conditioning_from onto conditioning_to along the token axis."""

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "concat"
    CATEGORY = "conditioning"

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "conditioning_to": ("CONDITIONING",),
            "conditioning_from": ("CONDITIONING",),
        }}

    def concat(self, conditioning_to, conditioning_from):
        if len(conditioning_from) > 1:
            logging.warning("Warning: ConditioningConcat conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")

        # Only the first "from" cond is appended to every "to" entry.
        cond_from = conditioning_from[0][0]
        merged = [
            [torch.cat((entry[0], cond_from), 1), entry[1].copy()]
            for entry in conditioning_to
        ]
        return (merged, )
139
+
140
class ConditioningSetArea:
    """Restrict conditioning to a pixel-aligned rectangular area."""

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "append"
    CATEGORY = "conditioning"

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "conditioning": ("CONDITIONING", ),
            "width": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
            "height": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
            "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
            "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
            "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
        }}

    def append(self, conditioning, width, height, x, y, strength):
        # Pixel coordinates are converted to latent units (//8).
        area = (height // 8, width // 8, y // 8, x // 8)
        values = {"area": area, "strength": strength, "set_area_to_bounds": False}
        return (node_helpers.conditioning_set_values(conditioning, values), )
160
+
161
class ConditioningSetAreaPercentage:
    """Like ConditioningSetArea, but the rectangle is given as fractions
    (0..1) of the image size instead of pixels."""
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning": ("CONDITIONING", ),
                             "width": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
                             "height": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
                             "x": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
                             "y": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                             }}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "append"

    CATEGORY = "conditioning"

    def append(self, conditioning, width, height, x, y, strength):
        # The "percentage" tag tells the downstream consumer to scale these
        # fractions against the actual latent size at sampling time.
        c = node_helpers.conditioning_set_values(conditioning, {"area": ("percentage", height, width, y, x),
                                                                "strength": strength,
                                                                "set_area_to_bounds": False})
        return (c, )
181
+
182
class ConditioningSetAreaStrength:
    """Set only the strength value on existing conditioning entries."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning": ("CONDITIONING", ),
                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                             }}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "append"

    CATEGORY = "conditioning"

    def append(self, conditioning, strength):
        updated = node_helpers.conditioning_set_values(conditioning, {"strength": strength})
        return (updated, )
196
+
197
+
198
class ConditioningSetMask:
    """Limit conditioning to a mask region, optionally shrinking the
    conditioning area to the mask's bounding box."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning": ("CONDITIONING", ),
                             "mask": ("MASK", ),
                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                             "set_cond_area": (["default", "mask bounds"],),
                             }}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "append"

    CATEGORY = "conditioning"

    def append(self, conditioning, mask, set_cond_area, strength):
        # Anything other than "default" means "mask bounds".
        use_bounds = set_cond_area != "default"
        if len(mask.shape) < 3:
            # Ensure a batch dimension on the mask.
            mask = mask.unsqueeze(0)

        c = node_helpers.conditioning_set_values(conditioning, {"mask": mask,
                                                                "set_area_to_bounds": use_bounds,
                                                                "mask_strength": strength})
        return (c, )
222
+
223
class ConditioningZeroOut:
    """Replace every cond tensor (and pooled output, when present) with
    zeros of the same shape, keeping all other metadata intact."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning": ("CONDITIONING", )}}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "zero_out"

    CATEGORY = "advanced/conditioning"

    def zero_out(self, conditioning):
        zeroed = []
        for t in conditioning:
            meta = t[1].copy()
            pooled = meta.get("pooled_output", None)
            if pooled is not None:
                meta["pooled_output"] = torch.zeros_like(pooled)
            zeroed.append([torch.zeros_like(t[0]), meta])
        return (zeroed, )
242
+
243
class ConditioningSetTimestepRange:
    """Constrain conditioning to a fraction of the sampling schedule,
    expressed as start/end percentages in [0, 1]."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning": ("CONDITIONING", ),
                             "start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
                             "end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
                             }}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "set_range"

    CATEGORY = "advanced/conditioning"

    def set_range(self, conditioning, start, end):
        updated = node_helpers.conditioning_set_values(conditioning, {"start_percent": start,
                                                                      "end_percent": end})
        return (updated, )
259
+
260
class VAEDecode:
    """Decode latent samples into an image tensor with the supplied VAE."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "decode"

    CATEGORY = "latent"

    def decode(self, vae, samples):
        images = vae.decode(samples["samples"])
        return (images, )
271
+
272
class VAEDecodeTiled:
    """Tiled VAE decode for lower peak memory; tile_size is in pixels and
    converted to latent units (// 8) for the decoder."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
                             "tile_size": ("INT", {"default": 512, "min": 320, "max": 4096, "step": 64})
                             }}
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "decode"

    CATEGORY = "_for_testing"

    def decode(self, vae, samples, tile_size):
        latent_tile = tile_size // 8
        return (vae.decode_tiled(samples["samples"], tile_x=latent_tile, tile_y=latent_tile, ), )
285
+
286
class VAEEncode:
    """Encode an image tensor to a latent with the supplied VAE."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "encode"

    CATEGORY = "latent"

    def encode(self, vae, pixels):
        # Drop any alpha channel before encoding (keep RGB only).
        latent = vae.encode(pixels[:, :, :, :3])
        return ({"samples": latent}, )
298
+
299
class VAEEncodeTiled:
    """Tiled VAE encode for lower peak memory; tile_size is in pixels."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ),
                             "tile_size": ("INT", {"default": 512, "min": 320, "max": 4096, "step": 64})
                             }}
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "encode"

    CATEGORY = "_for_testing"

    def encode(self, vae, pixels, tile_size):
        # Keep RGB only; encode tile by tile.
        latent = vae.encode_tiled(pixels[:, :, :, :3], tile_x=tile_size, tile_y=tile_size, )
        return ({"samples": latent}, )
313
+
314
class VAEEncodeForInpaint:
    """Encode pixels for inpainting: crop to the VAE's downscale grid, grey
    out the masked region before encoding, and return the latent together
    with a slightly dilated noise mask."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", ), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),}}
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "encode"

    CATEGORY = "latent/inpaint"

    def encode(self, vae, pixels, mask, grow_mask_by=6):
        # Largest height/width divisible by the VAE downscale ratio.
        x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio
        y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio
        # Resample the mask to pixel resolution (NCHW layout for interpolate).
        mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")

        pixels = pixels.clone()
        if pixels.shape[1] != x or pixels.shape[2] != y:
            # Center-crop both pixels and mask onto the aligned grid.
            x_offset = (pixels.shape[1] % vae.downscale_ratio) // 2
            y_offset = (pixels.shape[2] % vae.downscale_ratio) // 2
            pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
            mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]

        #grow mask by a few pixels to keep things seamless in latent space
        if grow_mask_by == 0:
            mask_erosion = mask
        else:
            # Dilate via an all-ones conv kernel, then clamp back into [0, 1].
            kernel_tensor = torch.ones((1, 1, grow_mask_by, grow_mask_by))
            padding = math.ceil((grow_mask_by - 1) / 2)

            mask_erosion = torch.clamp(torch.nn.functional.conv2d(mask.round(), kernel_tensor, padding=padding), 0, 1)

        # Neutralize masked pixels to 0.5 grey so the VAE does not encode
        # stale image content under the mask.
        m = (1.0 - mask.round()).squeeze(1)
        for i in range(3):
            pixels[:,:,:,i] -= 0.5
            pixels[:,:,:,i] *= m
            pixels[:,:,:,i] += 0.5
        t = vae.encode(pixels)

        return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
352
+
353
+
354
class InpaintModelConditioning:
    """Prepare positive/negative conditioning plus a latent for inpaint
    models: the masked-out image is VAE-encoded as concat conditioning,
    while the untouched image becomes the starting latent."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"positive": ("CONDITIONING", ),
                             "negative": ("CONDITIONING", ),
                             "vae": ("VAE", ),
                             "pixels": ("IMAGE", ),
                             "mask": ("MASK", ),
                             }}

    RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")
    FUNCTION = "encode"

    CATEGORY = "conditioning/inpaint"

    def encode(self, positive, negative, pixels, vae, mask):
        # Largest height/width divisible by 8 (latent downscale factor).
        x = (pixels.shape[1] // 8) * 8
        y = (pixels.shape[2] // 8) * 8
        # Resample the mask to pixel resolution (NCHW layout for interpolate).
        mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")

        orig_pixels = pixels
        pixels = orig_pixels.clone()
        if pixels.shape[1] != x or pixels.shape[2] != y:
            # Center-crop pixels and mask onto the 8-aligned grid.
            x_offset = (pixels.shape[1] % 8) // 2
            y_offset = (pixels.shape[2] % 8) // 2
            pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
            mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]

        # Grey out masked pixels (0.5) so the VAE encodes no stale content there.
        m = (1.0 - mask.round()).squeeze(1)
        for i in range(3):
            pixels[:,:,:,i] -= 0.5
            pixels[:,:,:,i] *= m
            pixels[:,:,:,i] += 0.5
        concat_latent = vae.encode(pixels)
        orig_latent = vae.encode(orig_pixels)

        out_latent = {}

        # The sampler starts from the unmodified image latent and only
        # re-noises the masked region.
        out_latent["samples"] = orig_latent
        out_latent["noise_mask"] = mask

        # Both positive and negative get the same concat latent + mask.
        out = []
        for conditioning in [positive, negative]:
            c = node_helpers.conditioning_set_values(conditioning, {"concat_latent_image": concat_latent,
                                                                    "concat_mask": mask})
            out.append(c)
        return (out[0], out[1], out_latent)
402
+
403
+
404
class SaveLatent:
    """Save latent samples into the output directory as a .latent
    safetensors file, embedding prompt/workflow metadata when enabled."""

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "samples": ("LATENT", ),
                              "filename_prefix": ("STRING", {"default": "latents/totoroUI"})},
                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
                }
    RETURN_TYPES = ()
    FUNCTION = "save"

    OUTPUT_NODE = True

    CATEGORY = "_for_testing"

    def save(self, samples, filename_prefix="totoroUI", prompt=None, extra_pnginfo=None):
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)

        # support save metadata for latent sharing
        prompt_info = ""
        if prompt is not None:
            prompt_info = json.dumps(prompt)

        metadata = None
        if not args.disable_metadata:
            metadata = {"prompt": prompt_info}
            if extra_pnginfo is not None:
                for x in extra_pnginfo:
                    metadata[x] = json.dumps(extra_pnginfo[x])

        # BUG FIX: the filename stem computed by get_save_image_path was
        # previously discarded and the literal "(unknown)" written instead;
        # use the real stem so saved files honor filename_prefix.
        file = f"{filename}_{counter:05}_.latent"

        # UI result entry pointing at the saved file.
        results = list()
        results.append({
            "filename": file,
            "subfolder": subfolder,
            "type": "output"
        })

        file = os.path.join(full_output_folder, file)

        # Marker tensor flags the new (unscaled) latent format for LoadLatent.
        output = {}
        output["latent_tensor"] = samples["samples"]
        output["latent_format_version_0"] = torch.tensor([])

        totoro.utils.save_torch_file(output, file, metadata=metadata)
        return { "ui": { "latents": results } }
453
+
454
+
455
class LoadLatent:
    """Load a .latent safetensors file from the input directory back into
    a LATENT dict, rescaling legacy files saved in the old format."""

    @classmethod
    def INPUT_TYPES(s):
        # Offer every .latent file in the input directory, sorted.
        input_dir = folder_paths.get_input_directory()
        files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
        return {"required": {"latent": [sorted(files), ]}, }

    CATEGORY = "_for_testing"

    RETURN_TYPES = ("LATENT", )
    FUNCTION = "load"

    def load(self, latent):
        latent_path = folder_paths.get_annotated_filepath(latent)
        latent = safetensors.torch.load_file(latent_path, device="cpu")
        multiplier = 1.0
        # Legacy files lack the version marker and were saved pre-scaled;
        # undo the SD1 latent scale factor (0.18215) on load.
        if "latent_format_version_0" not in latent:
            multiplier = 1.0 / 0.18215
        samples = {"samples": latent["latent_tensor"].float() * multiplier}
        return (samples, )

    @classmethod
    def IS_CHANGED(s, latent):
        # Hash file contents so on-disk edits invalidate the node cache.
        image_path = folder_paths.get_annotated_filepath(latent)
        m = hashlib.sha256()
        with open(image_path, 'rb') as f:
            m.update(f.read())
        return m.digest().hex()

    @classmethod
    def VALIDATE_INPUTS(s, latent):
        if not folder_paths.exists_annotated_filepath(latent):
            return "Invalid latent file: {}".format(latent)
        return True
489
+
490
+
491
class CheckpointLoader:
    """Load a checkpoint using an explicit model config file."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "config_name": (folder_paths.get_filename_list("configs"), ),
                              "ckpt_name": (folder_paths.get_filename_list("checkpoints"), )}}
    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
    FUNCTION = "load_checkpoint"

    CATEGORY = "advanced/loaders"

    def load_checkpoint(self, config_name, ckpt_name):
        cfg_path = folder_paths.get_full_path("configs", config_name)
        model_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        return totoro.sd.load_checkpoint(cfg_path, model_path, output_vae=True, output_clip=True,
                                         embedding_directory=folder_paths.get_folder_paths("embeddings"))
505
+
506
class CheckpointLoaderSimple:
    """Load a checkpoint, auto-detecting its configuration."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
                             }}
    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
    FUNCTION = "load_checkpoint"

    CATEGORY = "loaders"

    def load_checkpoint(self, ckpt_name):
        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        loaded = totoro.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True,
                                                        embedding_directory=folder_paths.get_folder_paths("embeddings"))
        # Only (MODEL, CLIP, VAE) are exposed; drop any extra outputs.
        return loaded[:3]
520
+
521
class DiffusersLoader:
    """Deprecated loader for diffusers-format model folders (identified by
    a model_index.json file)."""

    @classmethod
    def INPUT_TYPES(cls):
        paths = []
        for search_path in folder_paths.get_folder_paths("diffusers"):
            if not os.path.exists(search_path):
                continue
            for root, _, files in os.walk(search_path, followlinks=True):
                if "model_index.json" in files:
                    paths.append(os.path.relpath(root, start=search_path))

        return {"required": {"model_path": (paths,), }}
    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
    FUNCTION = "load_checkpoint"

    CATEGORY = "advanced/loaders/deprecated"

    def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
        # Resolve the relative model_path against the first search dir
        # that actually contains it.
        for search_path in folder_paths.get_folder_paths("diffusers"):
            if os.path.exists(search_path):
                candidate = os.path.join(search_path, model_path)
                if os.path.exists(candidate):
                    model_path = candidate
                    break

        return totoro.diffusers_load.load_diffusers(model_path, output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
546
+
547
+
548
class unCLIPCheckpointLoader:
    """Load an unCLIP checkpoint, additionally returning its CLIP vision model."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
                             }}
    RETURN_TYPES = ("MODEL", "CLIP", "VAE", "CLIP_VISION")
    FUNCTION = "load_checkpoint"

    CATEGORY = "loaders"

    def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        return totoro.sd.load_checkpoint_guess_config(
            ckpt_path, output_vae=True, output_clip=True, output_clipvision=True,
            embedding_directory=folder_paths.get_folder_paths("embeddings"))
562
+
563
class CLIPSetLastLayer:
    """Clone the CLIP model and truncate it at the given (negative) layer index."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "clip": ("CLIP", ),
                              "stop_at_clip_layer": ("INT", {"default": -1, "min": -24, "max": -1, "step": 1}),
                              }}
    RETURN_TYPES = ("CLIP",)
    FUNCTION = "set_last_layer"

    CATEGORY = "conditioning"

    def set_last_layer(self, clip, stop_at_clip_layer):
        # Clone so the input CLIP stays usable at full depth elsewhere.
        truncated = clip.clone()
        truncated.clip_layer(stop_at_clip_layer)
        return (truncated,)
578
+
579
class LoraLoader:
    """Apply a LoRA to a model/CLIP pair, caching the most recently loaded
    LoRA file so re-running with the same file skips disk I/O."""

    def __init__(self):
        # Single-entry cache: (lora_path, state_dict) of the last load.
        self.loaded_lora = None

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
                              "clip": ("CLIP", ),
                              "lora_name": (folder_paths.get_filename_list("loras"), ),
                              "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
                              "strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
                              }}
    RETURN_TYPES = ("MODEL", "CLIP")
    FUNCTION = "load_lora"

    CATEGORY = "loaders"

    def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
        # Both strengths zero: the LoRA would be a no-op, skip loading.
        if strength_model == 0 and strength_clip == 0:
            return (model, clip)

        lora_path = folder_paths.get_full_path("loras", lora_name)
        lora = None
        if self.loaded_lora is not None:
            if self.loaded_lora[0] == lora_path:
                # Cache hit: reuse the already-loaded state dict.
                lora = self.loaded_lora[1]
            else:
                # Different file requested: release the cached weights
                # before loading the new ones.
                temp = self.loaded_lora
                self.loaded_lora = None
                del temp

        if lora is None:
            lora = totoro.utils.load_torch_file(lora_path, safe_load=True)
            self.loaded_lora = (lora_path, lora)

        model_lora, clip_lora = totoro.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
        return (model_lora, clip_lora)
616
+
617
class LoraLoaderModelOnly(LoraLoader):
    """LoraLoader variant that patches only the diffusion model (no CLIP)."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
                              "lora_name": (folder_paths.get_filename_list("loras"), ),
                              "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
                              }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_lora_model_only"

    def load_lora_model_only(self, model, lora_name, strength_model):
        # Delegate with no CLIP and zero CLIP strength; keep only the model.
        patched_model, _ = self.load_lora(model, None, lora_name, strength_model, 0)
        return (patched_model,)
629
+
630
class VAELoader:
    """Load a VAE by name; also exposes the tiny "taesd*" approximate VAEs
    when both their encoder and decoder files are present in vae_approx."""

    @staticmethod
    def vae_list():
        # Full VAEs plus synthetic "taesd"/"taesdxl"/"taesd3" entries that
        # are only listed when both halves of the pair exist on disk.
        vaes = folder_paths.get_filename_list("vae")
        approx_vaes = folder_paths.get_filename_list("vae_approx")
        sdxl_taesd_enc = False
        sdxl_taesd_dec = False
        sd1_taesd_enc = False
        sd1_taesd_dec = False
        sd3_taesd_enc = False
        sd3_taesd_dec = False

        for v in approx_vaes:
            if v.startswith("taesd_decoder."):
                sd1_taesd_dec = True
            elif v.startswith("taesd_encoder."):
                sd1_taesd_enc = True
            elif v.startswith("taesdxl_decoder."):
                sdxl_taesd_dec = True
            elif v.startswith("taesdxl_encoder."):
                sdxl_taesd_enc = True
            elif v.startswith("taesd3_decoder."):
                sd3_taesd_dec = True
            elif v.startswith("taesd3_encoder."):
                sd3_taesd_enc = True
        if sd1_taesd_dec and sd1_taesd_enc:
            vaes.append("taesd")
        if sdxl_taesd_dec and sdxl_taesd_enc:
            vaes.append("taesdxl")
        if sd3_taesd_dec and sd3_taesd_enc:
            vaes.append("taesd3")
        return vaes

    @staticmethod
    def load_taesd(name):
        """Assemble a taesd state dict from separate encoder/decoder files,
        namespacing their keys and adding the model family's scale/shift."""
        sd = {}
        approx_vaes = folder_paths.get_filename_list("vae_approx")

        # First file matching each prefix; vae_list() guarantees both exist.
        encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
        decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))

        enc = totoro.utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder))
        for k in enc:
            sd["taesd_encoder.{}".format(k)] = enc[k]

        dec = totoro.utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder))
        for k in dec:
            sd["taesd_decoder.{}".format(k)] = dec[k]

        # Per-family latent scale/shift constants.
        if name == "taesd":
            sd["vae_scale"] = torch.tensor(0.18215)
            sd["vae_shift"] = torch.tensor(0.0)
        elif name == "taesdxl":
            sd["vae_scale"] = torch.tensor(0.13025)
            sd["vae_shift"] = torch.tensor(0.0)
        elif name == "taesd3":
            sd["vae_scale"] = torch.tensor(1.5305)
            sd["vae_shift"] = torch.tensor(0.0609)
        return sd

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "vae_name": (s.vae_list(), )}}
    RETURN_TYPES = ("VAE",)
    FUNCTION = "load_vae"

    CATEGORY = "loaders"

    #TODO: scale factor?
    def load_vae(self, vae_name):
        # Synthetic taesd names are assembled from pairs; everything else
        # is a normal single-file VAE.
        if vae_name in ["taesd", "taesdxl", "taesd3"]:
            sd = self.load_taesd(vae_name)
        else:
            vae_path = folder_paths.get_full_path("vae", vae_name)
            sd = totoro.utils.load_torch_file(vae_path)
        vae = totoro.sd.VAE(sd=sd)
        return (vae,)
707
+
708
class ControlNetLoader:
    """Load a ControlNet model by file name."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "control_net_name": (folder_paths.get_filename_list("controlnet"), )}}

    RETURN_TYPES = ("CONTROL_NET",)
    FUNCTION = "load_controlnet"

    CATEGORY = "loaders"

    def load_controlnet(self, control_net_name):
        path = folder_paths.get_full_path("controlnet", control_net_name)
        return (totoro.controlnet.load_controlnet(path),)
722
+
723
class DiffControlNetLoader:
    """Load a diff ControlNet, which needs the base model to be applied to."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
                              "control_net_name": (folder_paths.get_filename_list("controlnet"), )}}

    RETURN_TYPES = ("CONTROL_NET",)
    FUNCTION = "load_controlnet"

    CATEGORY = "loaders"

    def load_controlnet(self, model, control_net_name):
        path = folder_paths.get_full_path("controlnet", control_net_name)
        return (totoro.controlnet.load_controlnet(path, model),)
738
+
739
+
740
class ControlNetApply:
    """Attach a ControlNet (with an image hint) to every conditioning
    entry, chaining after any ControlNet already present."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning": ("CONDITIONING", ),
                             "control_net": ("CONTROL_NET", ),
                             "image": ("IMAGE", ),
                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01})
                             }}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "apply_controlnet"

    CATEGORY = "conditioning/controlnet"

    def apply_controlnet(self, conditioning, control_net, image, strength):
        # Zero strength would be a no-op; return the input untouched.
        if strength == 0:
            return (conditioning, )

        hint = image.movedim(-1, 1)  # NHWC -> NCHW for the controlnet
        out = []
        for t in conditioning:
            entry = [t[0], t[1].copy()]
            c_net = control_net.copy().set_cond_hint(hint, strength)
            if 'control' in t[1]:
                # Chain behind the controlnet that was already attached.
                c_net.set_previous_controlnet(t[1]['control'])
            entry[1]['control'] = c_net
            entry[1]['control_apply_to_uncond'] = True
            out.append(entry)
        return (out, )
768
+
769
+
770
class ControlNetApplyAdvanced:
    """Attach a ControlNet to positive and negative conditioning separately,
    with a start/end window over the sampling schedule. ControlNet copies
    are deduplicated per previous-controlnet so entries that shared a chain
    keep sharing one instance."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"positive": ("CONDITIONING", ),
                             "negative": ("CONDITIONING", ),
                             "control_net": ("CONTROL_NET", ),
                             "image": ("IMAGE", ),
                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                             "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
                             "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
                             }}

    RETURN_TYPES = ("CONDITIONING","CONDITIONING")
    RETURN_NAMES = ("positive", "negative")
    FUNCTION = "apply_controlnet"

    CATEGORY = "conditioning/controlnet"

    def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None):
        # Zero strength would be a no-op; return the inputs untouched.
        if strength == 0:
            return (positive, negative)

        control_hint = image.movedim(-1,1)  # NHWC -> NCHW
        # Dedup map keyed by the previous controlnet in the chain (or None).
        cnets = {}

        out = []
        for conditioning in [positive, negative]:
            c = []
            for t in conditioning:
                d = t[1].copy()

                prev_cnet = d.get('control', None)
                if prev_cnet in cnets:
                    # Reuse the copy already made for this chain.
                    c_net = cnets[prev_cnet]
                else:
                    c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae)
                    c_net.set_previous_controlnet(prev_cnet)
                    cnets[prev_cnet] = c_net

                d['control'] = c_net
                d['control_apply_to_uncond'] = False
                n = [t[0], d]
                c.append(n)
            out.append(c)
        return (out[0], out[1])
815
+
816
+
817
class UNETLoader:
    """Load a standalone diffusion model (UNet), optionally casting the
    weights to an fp8 storage dtype."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "unet_name": (folder_paths.get_filename_list("unet"), ),
                              "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],)
                             }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_unet"

    CATEGORY = "advanced/loaders"

    def load_unet(self, unet_name, weight_dtype):
        # BUG FIX: "fp8_e5m2" previously mapped to torch.float8_e4m3fn,
        # silently loading the wrong fp8 format for that selection.
        dtype_map = {"default": None,
                     "fp8_e4m3fn": torch.float8_e4m3fn,
                     "fp8_e5m2": torch.float8_e5m2}
        unet_path = folder_paths.get_full_path("unet", unet_name)
        model = totoro.sd.load_unet(unet_path, dtype=dtype_map[weight_dtype])
        return (model,)
833
+
834
class CLIPLoader:
    """Load a single text encoder, picking the CLIP type by model family."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ),
                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio"], ),
                             }}
    RETURN_TYPES = ("CLIP",)
    FUNCTION = "load_clip"

    CATEGORY = "advanced/loaders"

    def load_clip(self, clip_name, type="stable_diffusion"):
        # Map the requested family onto a CLIPType; anything unrecognized
        # falls back to plain stable diffusion (same as the original chain).
        type_map = {
            "stable_cascade": totoro.sd.CLIPType.STABLE_CASCADE,
            "sd3": totoro.sd.CLIPType.SD3,
            "stable_audio": totoro.sd.CLIPType.STABLE_AUDIO,
        }
        clip_type = type_map.get(type, totoro.sd.CLIPType.STABLE_DIFFUSION)

        clip_path = folder_paths.get_full_path("clip", clip_name)
        clip = totoro.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
        return (clip,)
858
+
859
class DualCLIPLoader:
    """Load a pair of text encoders (e.g. clip-l + t5) for dual-encoder
    model families (sdxl, sd3, flux)."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ),
                              "clip_name2": (folder_paths.get_filename_list("clip"), ),
                              "type": (["sdxl", "sd3", "flux"], ),
                             }}
    RETURN_TYPES = ("CLIP",)
    FUNCTION = "load_clip"

    CATEGORY = "advanced/loaders"

    def load_clip(self, clip_name1, clip_name2, type):
        clip_path1 = folder_paths.get_full_path("clip", clip_name1)
        clip_path2 = folder_paths.get_full_path("clip", clip_name2)
        if type == "sdxl":
            clip_type = totoro.sd.CLIPType.STABLE_DIFFUSION
        elif type == "sd3":
            clip_type = totoro.sd.CLIPType.SD3
        elif type == "flux":
            clip_type = totoro.sd.CLIPType.FLUX
        else:
            # ROBUSTNESS FIX: an unknown type previously left clip_type
            # unbound (UnboundLocalError); fall back like CLIPLoader does.
            clip_type = totoro.sd.CLIPType.STABLE_DIFFUSION

        clip = totoro.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
        return (clip,)
883
+
884
class CLIPVisionLoader:
    """Load a CLIP vision model by file name."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "clip_name": (folder_paths.get_filename_list("clip_vision"), ),
                             }}
    RETURN_TYPES = ("CLIP_VISION",)
    FUNCTION = "load_clip"

    CATEGORY = "loaders"

    def load_clip(self, clip_name):
        path = folder_paths.get_full_path("clip_vision", clip_name)
        return (totoro.clip_vision.load(path),)
898
+
899
class CLIPVisionEncode:
    """Run an image through a CLIP vision model and return its output."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "clip_vision": ("CLIP_VISION",),
                              "image": ("IMAGE",)
                             }}
    RETURN_TYPES = ("CLIP_VISION_OUTPUT",)
    FUNCTION = "encode"

    CATEGORY = "conditioning"

    def encode(self, clip_vision, image):
        return (clip_vision.encode_image(image),)
913
+
914
class StyleModelLoader:
    """Load a style model by file name."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "style_model_name": (folder_paths.get_filename_list("style_models"), )}}

    RETURN_TYPES = ("STYLE_MODEL",)
    FUNCTION = "load_style_model"

    CATEGORY = "loaders"

    def load_style_model(self, style_model_name):
        path = folder_paths.get_full_path("style_models", style_model_name)
        return (totoro.sd.load_style_model(path),)
928
+
929
+
930
class StyleModelApply:
    """Append the style model's cond tokens (derived from a CLIP vision
    output) to every conditioning entry along the token axis."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning": ("CONDITIONING", ),
                             "style_model": ("STYLE_MODEL", ),
                             "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
                             }}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "apply_stylemodel"

    CATEGORY = "conditioning/style_model"

    def apply_stylemodel(self, clip_vision_output, style_model, conditioning):
        # Flatten to a single token sequence with a batch dim of 1.
        style_cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)
        out = []
        for t in conditioning:
            out.append([torch.cat((t[0], style_cond), dim=1), t[1].copy()])
        return (out, )
949
+
950
class unCLIPConditioning:
    """Append an unCLIP (image-embed) entry to each conditioning item's
    ``unclip_conditioning`` list; zero strength is a pass-through."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning": ("CONDITIONING", ),
                             "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
                             "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
                             "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                             }}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "apply_adm"

    CATEGORY = "conditioning"

    def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation):
        if strength == 0:
            return (conditioning, )

        out = []
        for t in conditioning:
            meta = t[1].copy()
            entry = {"clip_vision_output": clip_vision_output,
                     "strength": strength,
                     "noise_augmentation": noise_augmentation}
            # Copy any existing list so shared dicts are never mutated.
            meta["unclip_conditioning"] = meta.get("unclip_conditioning", [])[:] + [entry]
            out.append([t[0], meta])
        return (out, )
978
+
979
class GLIGENLoader:
    """Load a GLIGEN model by file name."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "gligen_name": (folder_paths.get_filename_list("gligen"), )}}

    RETURN_TYPES = ("GLIGEN",)
    FUNCTION = "load_gligen"

    CATEGORY = "loaders"

    def load_gligen(self, gligen_name):
        path = folder_paths.get_full_path("gligen", gligen_name)
        return (totoro.sd.load_gligen(path),)
993
+
994
class GLIGENTextBoxApply:
    """Attach a GLIGEN "position" phrase to conditioning: the text is
    CLIP-encoded (unprojected pooled output) and bound to a pixel box
    stored in latent units (pixel // 8)."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning_to": ("CONDITIONING", ),
                             "clip": ("CLIP", ),
                             "gligen_textbox_model": ("GLIGEN", ),
                             "text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                             "width": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
                             "height": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
                             "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                             "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                             }}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "append"

    CATEGORY = "conditioning/gligen"

    def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y):
        c = []
        # Encode the phrase; the pooled (unprojected) vector is what GLIGEN uses.
        cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled="unprojected")
        for t in conditioning_to:
            n = [t[0], t[1].copy()]
            # (embedding, h, w, y, x) in latent units.
            position_params = [(cond_pooled, height // 8, width // 8, y // 8, x // 8)]
            prev = []
            if "gligen" in n[1]:
                # Accumulate on top of any boxes applied earlier.
                prev = n[1]['gligen'][2]

            n[1]['gligen'] = ("position", gligen_textbox_model, prev + position_params)
            c.append(n)
        return (c, )
1024
+
1025
class EmptyLatentImage:
    """Produce a batch of zero-filled 4-channel latents sized for an 8x
    pixel-to-latent downscale."""

    def __init__(self):
        self.device = totoro.model_management.intermediate_device()

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "width": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
                              "height": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "generate"

    CATEGORY = "latent"

    def generate(self, width, height, batch_size=1):
        shape = [batch_size, 4, height // 8, width // 8]
        return ({"samples": torch.zeros(shape, device=self.device)}, )
1042
+
1043
+
1044
class LatentFromBatch:
    """Slice a sub-batch [batch_index, batch_index+length) out of a latent
    batch, carrying the matching noise mask and batch indices along."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "samples": ("LATENT",),
                              "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
                              "length": ("INT", {"default": 1, "min": 1, "max": 64}),
                              }}
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "frombatch"

    CATEGORY = "latent/batch"

    def frombatch(self, samples, batch_index, length):
        result = samples.copy()
        latents = samples["samples"]
        # Clamp the window so it always lies inside the batch.
        batch_index = min(latents.shape[0] - 1, batch_index)
        length = min(latents.shape[0] - batch_index, length)
        result["samples"] = latents[batch_index:batch_index + length].clone()

        if "noise_mask" in samples:
            masks = samples["noise_mask"]
            if masks.shape[0] == 1:
                # A single shared mask applies to any slice as-is.
                result["noise_mask"] = masks.clone()
            else:
                if masks.shape[0] < latents.shape[0]:
                    # Tile the mask up to the latent batch size first.
                    masks = masks.repeat(math.ceil(latents.shape[0] / masks.shape[0]), 1, 1, 1)[:latents.shape[0]]
                result["noise_mask"] = masks[batch_index:batch_index + length].clone()

        if "batch_index" not in result:
            result["batch_index"] = list(range(batch_index, batch_index + length))
        else:
            result["batch_index"] = samples["batch_index"][batch_index:batch_index + length]
        return (result,)
1075
+
1076
class RepeatLatentBatch:
    """Repeat every latent in the batch `amount` times (concatenated along dim 0).

    Also repeats the noise mask (padding it to the latent batch size first)
    and extends `batch_index` so each repeated copy gets a distinct,
    non-overlapping index range (keeps per-item noise seeds unique).
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"samples": ("LATENT",),
                             "amount": ("INT", {"default": 1, "min": 1, "max": 64}),
                             }}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "repeat"
    CATEGORY = "latent/batch"

    def repeat(self, samples, amount):
        s = samples.copy()
        s_in = samples["samples"]

        s["samples"] = s_in.repeat((amount, 1, 1, 1))
        if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1:
            masks = samples["noise_mask"]
            if masks.shape[0] < s_in.shape[0]:
                # Pad the mask batch up to the latent batch size first.
                masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
            # Bug fix: repeat the padded `masks`, not the original (possibly
            # shorter) samples["noise_mask"], so the repeated mask batch stays
            # aligned 1:1 with the repeated latents.
            s["noise_mask"] = masks.repeat((amount, 1, 1, 1))
        if "batch_index" in s:
            # Shift each repeated copy past the span of the existing indices.
            offset = max(s["batch_index"]) - min(s["batch_index"]) + 1
            s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]]
        return (s,)
1101
+
1102
class LatentUpscale:
    """Resize a latent to an explicit pixel width/height (0 keeps aspect ratio)."""

    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"]
    crop_methods = ["disabled", "center"]

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"samples": ("LATENT",), "upscale_method": (s.upscale_methods,),
                             "width": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                             "height": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                             "crop": (s.crop_methods,)}}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "upscale"
    CATEGORY = "latent"

    def upscale(self, samples, upscale_method, width, height, crop):
        if width == 0 and height == 0:
            # No target at all: pass the latent through untouched.
            return (samples,)

        out = samples.copy()
        lat = samples["samples"]
        if width == 0:
            # Derive width from height, preserving aspect ratio (min 64 px).
            height = max(64, height)
            width = max(64, round(lat.shape[3] * height / lat.shape[2]))
        elif height == 0:
            # Derive height from width, preserving aspect ratio.
            width = max(64, width)
            height = max(64, round(lat.shape[2] * width / lat.shape[3]))
        else:
            width = max(64, width)
            height = max(64, height)

        # Target is given in pixels; the latent is 1/8 scale.
        out["samples"] = totoro.utils.common_upscale(lat, width // 8, height // 8, upscale_method, crop)
        return (out,)
1135
+
1136
class LatentUpscaleBy:
    """Scale a latent's spatial size by a multiplicative factor."""

    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"]

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"samples": ("LATENT",), "upscale_method": (s.upscale_methods,),
                             "scale_by": ("FLOAT", {"default": 1.5, "min": 0.01, "max": 8.0, "step": 0.01}),}}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "upscale"
    CATEGORY = "latent"

    def upscale(self, samples, upscale_method, scale_by):
        out = samples.copy()
        lat = samples["samples"]
        # Dimensions here are already latent-sized, so no //8 is needed.
        new_width = round(lat.shape[3] * scale_by)
        new_height = round(lat.shape[2] * scale_by)
        out["samples"] = totoro.utils.common_upscale(lat, new_width, new_height, upscale_method, "disabled")
        return (out,)
1154
+
1155
class LatentRotate:
    """Rotate latents in 90-degree increments."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"samples": ("LATENT",),
                             "rotation": (["none", "90 degrees", "180 degrees", "270 degrees"],),
                             }}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "rotate"
    CATEGORY = "latent/transform"

    def rotate(self, samples, rotation):
        out = samples.copy()
        # Translate the option string into a number of quarter turns.
        quarter_turns = 0
        for prefix, k in (("90", 1), ("180", 2), ("270", 3)):
            if rotation.startswith(prefix):
                quarter_turns = k
                break
        # dims=[3, 2] rotates within the (width, height) plane.
        out["samples"] = torch.rot90(samples["samples"], k=quarter_turns, dims=[3, 2])
        return (out,)
1178
+
1179
class LatentFlip:
    """Flip latents vertically (x-axis option) or horizontally (y-axis option)."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"samples": ("LATENT",),
                             "flip_method": (["x-axis: vertically", "y-axis: horizontally"],),
                             }}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "flip"
    CATEGORY = "latent/transform"

    def flip(self, samples, flip_method):
        out = samples.copy()
        # dim 2 is height (vertical flip), dim 3 is width (horizontal flip).
        axis = None
        if flip_method.startswith("x"):
            axis = 2
        elif flip_method.startswith("y"):
            axis = 3
        if axis is not None:
            out["samples"] = torch.flip(samples["samples"], dims=[axis])
        return (out,)
1198
+
1199
class LatentComposite:
    """Paste `samples_from` into `samples_to` at (x, y), with optional feathering."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "samples_to": ("LATENT",),
                              "samples_from": ("LATENT",),
                              "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                              "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                              "feather": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                              }}
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "composite"

    CATEGORY = "latent"

    def composite(self, samples_to, samples_from, x, y, composite_method="normal", feather=0):
        """Return a one-tuple latent with samples_from composited onto samples_to.

        x, y, feather are given in pixels and converted to latent units (//8).
        feather > 0 alpha-blends a linear ramp on each pasted edge that does
        not coincide with the destination border.
        """
        # Pixel coordinates -> latent coordinates.
        x = x // 8
        y = y // 8
        feather = feather // 8
        samples_out = samples_to.copy()
        s = samples_to["samples"].clone()
        samples_to = samples_to["samples"]
        samples_from = samples_from["samples"]
        if feather == 0:
            # Hard paste, source clipped to the destination bounds.
            s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
        else:
            samples_from = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
            mask = torch.ones_like(samples_from)
            # Build a linear feather ramp (1/feather .. 1) on each edge of the
            # source that lands strictly inside the destination.
            for t in range(feather):
                if y != 0:
                    mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1))

                if y + samples_from.shape[2] < samples_to.shape[2]:
                    mask[:,:,mask.shape[2] -1 -t: mask.shape[2]-t,:] *= ((1.0/feather) * (t + 1))
                if x != 0:
                    mask[:,:,:,t:1+t] *= ((1.0/feather) * (t + 1))
                if x + samples_from.shape[3] < samples_to.shape[3]:
                    mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1))
            rev_mask = torch.ones_like(mask) - mask
            # Alpha-blend source over destination using the feather mask.
            s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] * mask + s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] * rev_mask
        samples_out["samples"] = s
        return (samples_out,)
1240
+
1241
class LatentBlend:
    """Blend two latents: samples1 * blend_factor + samples2 * (1 - blend_factor)."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "samples1": ("LATENT",),
            "samples2": ("LATENT",),
            "blend_factor": ("FLOAT", {
                "default": 0.5,
                "min": 0,
                "max": 1,
                "step": 0.01
            }),
        }}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "blend"

    CATEGORY = "_for_testing"

    def blend(self, samples1, samples2, blend_factor: float, blend_mode: str = "normal"):
        """Return a one-tuple latent blending samples2 into samples1.

        If the shapes differ, samples2 is bicubic-resized to samples1's
        spatial size first. Raises ValueError for unknown blend modes.
        """
        samples_out = samples1.copy()
        samples1 = samples1["samples"]
        samples2 = samples2["samples"]

        if samples1.shape != samples2.shape:
            # Bug fix: the original called samples2.permute(0, 3, 1, 2) and
            # .permute(0, 2, 3, 1) around this without assigning the results --
            # pure no-ops, since permute returns a new tensor. The dead calls
            # are removed; latents are already NCHW as common_upscale expects.
            samples2 = totoro.utils.common_upscale(samples2, samples1.shape[3], samples1.shape[2], 'bicubic', crop='center')

        samples_blended = self.blend_mode(samples1, samples2, blend_mode)
        samples_blended = samples1 * blend_factor + samples_blended * (1 - blend_factor)
        samples_out["samples"] = samples_blended
        return (samples_out,)

    def blend_mode(self, img1, img2, mode):
        # Only "normal" is implemented: the blended layer is simply img2.
        if mode == "normal":
            return img2
        else:
            raise ValueError(f"Unsupported blend mode: {mode}")
1281
+
1282
class LatentCrop:
    """Crop a width x height region from a latent at pixel offset (x, y)."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"samples": ("LATENT",),
                             "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
                             "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
                             "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                             "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                             }}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "crop"
    CATEGORY = "latent/transform"

    def crop(self, samples, width, height, x, y):
        out = samples.copy()
        lat = samples['samples']
        # Pixel -> latent coordinates.
        x = x // 8
        y = y // 8

        # Enforce a minimum crop of 64 pixels (8 latent cells): keep the
        # origin at least 8 cells away from the far edge.
        x = min(x, lat.shape[3] - 8)
        y = min(y, lat.shape[2] - 8)

        to_x = width // 8 + x
        to_y = height // 8 + y
        out['samples'] = lat[:, :, y:to_y, x:to_x]
        return (out,)
1314
+
1315
class SetLatentNoiseMask:
    """Attach an inpainting noise mask to a latent dict."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"samples": ("LATENT",),
                             "mask": ("MASK",),
                             }}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "set_mask"
    CATEGORY = "latent/inpaint"

    def set_mask(self, samples, mask):
        out = samples.copy()
        # Store as (batch, 1, H, W) so it broadcasts over latent channels.
        out["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
        return (out,)
1330
+
1331
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
    """Shared sampling path used by KSampler and KSamplerAdvanced.

    Prepares noise (or zeros when disable_noise is set), forwards the
    latent's optional noise mask and batch indices, wires up the latent
    preview callback, and runs totoro.sample.sample.

    Returns a one-tuple containing a copy of `latent` whose "samples"
    hold the denoised result.
    """
    latent_image = latent["samples"]
    latent_image = totoro.sample.fix_empty_latent_channels(model, latent_image)

    if disable_noise:
        # Caller asked not to add noise: start from the latent as-is.
        noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
    else:
        # batch_index keeps per-item noise reproducible across batch slicing.
        batch_inds = latent["batch_index"] if "batch_index" in latent else None
        noise = totoro.sample.prepare_noise(latent_image, seed, batch_inds)

    noise_mask = None
    if "noise_mask" in latent:
        noise_mask = latent["noise_mask"]

    callback = latent_preview.prepare_callback(model, steps)
    disable_pbar = not totoro.utils.PROGRESS_BAR_ENABLED
    samples = totoro.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
                                  denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
                                  force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
    out = latent.copy()
    out["samples"] = samples
    return (out, )
1353
+
1354
class KSampler:
    """Standard sampling node: denoise a latent with the given model and prompts.

    Thin UI wrapper around common_ksampler(); `denoise` < 1.0 performs a
    partial denoise (img2img-style strength).
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"model": ("MODEL",),
                    "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                    "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
                    "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
                    "sampler_name": (totoro.samplers.KSampler.SAMPLERS, ),
                    "scheduler": (totoro.samplers.KSampler.SCHEDULERS, ),
                    "positive": ("CONDITIONING", ),
                    "negative": ("CONDITIONING", ),
                    "latent_image": ("LATENT", ),
                    "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                     }
                }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "sample"

    CATEGORY = "sampling"

    def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0):
        # All real work happens in the shared helper.
        return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
1378
+
1379
class KSamplerAdvanced:
    """Sampler with explicit step-range control for multi-stage workflows.

    start_at_step/end_at_step select the sub-range of steps to run, while
    add_noise and return_with_leftover_noise let several sampler stages be
    chained on the same latent.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"model": ("MODEL",),
                    "add_noise": (["enable", "disable"], ),
                    "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                    "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
                    "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
                    "sampler_name": (totoro.samplers.KSampler.SAMPLERS, ),
                    "scheduler": (totoro.samplers.KSampler.SCHEDULERS, ),
                    "positive": ("CONDITIONING", ),
                    "negative": ("CONDITIONING", ),
                    "latent_image": ("LATENT", ),
                    "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
                    "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}),
                    "return_with_leftover_noise": (["disable", "enable"], ),
                     }
                }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "sample"

    CATEGORY = "sampling"

    def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0):
        # Translate the UI's enable/disable strings into the sampler's
        # boolean flags before delegating to the shared helper.
        force_full_denoise = True
        if return_with_leftover_noise == "enable":
            force_full_denoise = False
        disable_noise = False
        if add_noise == "disable":
            disable_noise = True
        return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise)
1412
+
1413
class SaveImage:
    """Write a batch of images to the output directory as PNGs.

    Embeds the prompt and any extra PNG info as tEXt metadata unless
    metadata is disabled via command-line args.
    """

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()
        self.type = "output"      # result type reported back to the UI
        self.prefix_append = ""   # extra suffix appended to filename_prefix
        self.compress_level = 4   # PNG compression level

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"images": ("IMAGE", ),
                     "filename_prefix": ("STRING", {"default": "totoroUI"})},
                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
                }

    RETURN_TYPES = ()
    FUNCTION = "save_images"

    OUTPUT_NODE = True

    CATEGORY = "image"

    def save_images(self, images, filename_prefix="totoroUI", prompt=None, extra_pnginfo=None):
        """Save each image; returns UI info describing the written files."""
        filename_prefix += self.prefix_append
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
        results = list()
        for (batch_number, image) in enumerate(images):
            # Float [0, 1] tensor -> clipped uint8 PIL image.
            i = 255. * image.cpu().numpy()
            img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
            metadata = None
            if not args.disable_metadata:
                # Embed the workflow so images can be dragged back into the UI.
                metadata = PngInfo()
                if prompt is not None:
                    metadata.add_text("prompt", json.dumps(prompt))
                if extra_pnginfo is not None:
                    for x in extra_pnginfo:
                        metadata.add_text(x, json.dumps(extra_pnginfo[x]))

            filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
            file = f"{filename_with_batch_num}_{counter:05}_.png"
            img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level)
            results.append({
                "filename": file,
                "subfolder": subfolder,
                "type": self.type
            })
            counter += 1

        return { "ui": { "images": results } }
1462
+
1463
class PreviewImage(SaveImage):
    """SaveImage variant that writes previews to the temp directory.

    Uses a random filename suffix so concurrent previews do not collide,
    and a low PNG compression level for speed.
    """

    def __init__(self):
        self.output_dir = folder_paths.get_temp_directory()
        self.type = "temp"
        # Bug fix: the suffix alphabet previously read "...qrstupvxyz"
        # (duplicate 'p', missing 'w'); use the correct lowercase alphabet.
        self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstuvwxyz") for x in range(5))
        self.compress_level = 1

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"images": ("IMAGE", ), },
                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
                }
1476
+
1477
class LoadImage:
    """Load an image from the input directory, returning (IMAGE, MASK)."""

    @classmethod
    def INPUT_TYPES(s):
        input_dir = folder_paths.get_input_directory()
        files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
        return {"required":
                    {"image": (sorted(files), {"image_upload": True})},
                }

    CATEGORY = "image"

    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "load_image"
    def load_image(self, image):
        """Return (images, masks); multi-frame files become a batch.

        The mask is the inverted alpha channel (1 = transparent/masked),
        or zeros when the image has no alpha.
        """
        image_path = folder_paths.get_annotated_filepath(image)

        img = node_helpers.pillow(Image.open, image_path)

        output_images = []
        output_masks = []
        w, h = None, None

        # Formats whose extra frames are not batched; only the first frame
        # is returned for these (see the final len(output_images) check).
        excluded_formats = ['MPO']

        for i in ImageSequence.Iterator(img):
            i = node_helpers.pillow(ImageOps.exif_transpose, i)

            if i.mode == 'I':
                # 32-bit integer mode: rescale values into the 8-bit range.
                i = i.point(lambda i: i * (1 / 255))
            image = i.convert("RGB")

            if len(output_images) == 0:
                # First frame fixes the batch's width/height.
                w = image.size[0]
                h = image.size[1]

            # Drop frames whose size differs from the first frame.
            if image.size[0] != w or image.size[1] != h:
                continue

            image = np.array(image).astype(np.float32) / 255.0
            image = torch.from_numpy(image)[None,]
            if 'A' in i.getbands():
                # Invert alpha so the mask is 1 where the image is transparent.
                mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
                mask = 1. - torch.from_numpy(mask)
            else:
                mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
            output_images.append(image)
            output_masks.append(mask.unsqueeze(0))

        if len(output_images) > 1 and img.format not in excluded_formats:
            output_image = torch.cat(output_images, dim=0)
            output_mask = torch.cat(output_masks, dim=0)
        else:
            output_image = output_images[0]
            output_mask = output_masks[0]

        return (output_image, output_mask)

    @classmethod
    def IS_CHANGED(s, image):
        # Hash the file contents so edits under the same filename re-trigger.
        image_path = folder_paths.get_annotated_filepath(image)
        m = hashlib.sha256()
        with open(image_path, 'rb') as f:
            m.update(f.read())
        return m.digest().hex()

    @classmethod
    def VALIDATE_INPUTS(s, image):
        if not folder_paths.exists_annotated_filepath(image):
            return "Invalid image file: {}".format(image)

        return True
1548
+
1549
class LoadImageMask:
    """Load a single channel of an image from the input directory as a MASK."""

    # Channel options; the first letter maps to the PIL band name.
    _color_channels = ["alpha", "red", "green", "blue"]
    @classmethod
    def INPUT_TYPES(s):
        input_dir = folder_paths.get_input_directory()
        files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
        return {"required":
                    {"image": (sorted(files), {"image_upload": True}),
                     "channel": (s._color_channels, ), }
                }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)
    FUNCTION = "load_image"
    def load_image(self, image, channel):
        """Return a (1, H, W) float mask from the selected channel.

        The alpha channel is inverted (1 = transparent); a missing channel
        yields a 64x64 zero mask.
        """
        image_path = folder_paths.get_annotated_filepath(image)
        i = node_helpers.pillow(Image.open, image_path)
        i = node_helpers.pillow(ImageOps.exif_transpose, i)
        if i.getbands() != ("R", "G", "B", "A"):
            if i.mode == 'I':
                # 32-bit integer mode: rescale values into the 8-bit range.
                i = i.point(lambda i: i * (1 / 255))
            i = i.convert("RGBA")
        mask = None
        # "alpha" -> 'A', "red" -> 'R', etc.
        c = channel[0].upper()
        if c in i.getbands():
            mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0
            mask = torch.from_numpy(mask)
            if c == 'A':
                # Invert alpha so the mask is 1 where the image is transparent.
                mask = 1. - mask
        else:
            mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
        return (mask.unsqueeze(0),)

    @classmethod
    def IS_CHANGED(s, image, channel):
        # Hash the file contents so edits under the same filename re-trigger.
        image_path = folder_paths.get_annotated_filepath(image)
        m = hashlib.sha256()
        with open(image_path, 'rb') as f:
            m.update(f.read())
        return m.digest().hex()

    @classmethod
    def VALIDATE_INPUTS(s, image):
        if not folder_paths.exists_annotated_filepath(image):
            return "Invalid image file: {}".format(image)

        return True
1597
+
1598
class ImageScale:
    """Resize an image to an explicit width/height (0 keeps aspect ratio)."""

    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
    crop_methods = ["disabled", "center"]

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"image": ("IMAGE",), "upscale_method": (s.upscale_methods,),
                             "width": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                             "height": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                             "crop": (s.crop_methods,)}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "upscale"
    CATEGORY = "image/upscaling"

    def upscale(self, image, upscale_method, width, height, crop):
        if width == 0 and height == 0:
            # No target size given: return the image unchanged.
            return (image,)

        # NHWC -> NCHW for common_upscale, then back afterwards.
        samples = image.movedim(-1, 1)
        if width == 0:
            # Derive the missing dimension to preserve aspect ratio.
            width = max(1, round(samples.shape[3] * height / samples.shape[2]))
        elif height == 0:
            height = max(1, round(samples.shape[2] * width / samples.shape[3]))
        resized = totoro.utils.common_upscale(samples, width, height, upscale_method, crop)
        return (resized.movedim(1, -1),)
1627
+
1628
class ImageScaleBy:
    """Scale an image's spatial size by a multiplicative factor."""

    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,),
                              "scale_by": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 8.0, "step": 0.01}),}}
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "upscale"

    CATEGORY = "image/upscaling"

    def upscale(self, image, upscale_method, scale_by):
        # NHWC -> NCHW for common_upscale, then back afterwards.
        samples = image.movedim(-1,1)
        width = round(samples.shape[3] * scale_by)
        height = round(samples.shape[2] * scale_by)
        s = totoro.utils.common_upscale(samples, width, height, upscale_method, "disabled")
        s = s.movedim(1,-1)
        return (s,)
1647
+
1648
class ImageInvert:
    """Invert image colors: out = 1 - in (values assumed in [0, 1])."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"image": ("IMAGE",)}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "invert"
    CATEGORY = "image"

    def invert(self, image):
        inverted = 1.0 - image
        return (inverted,)
1662
+
1663
class ImageBatch:
    """Concatenate two images along the batch dimension."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"image1": ("IMAGE",), "image2": ("IMAGE",)}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "batch"
    CATEGORY = "image"

    def batch(self, image1, image2):
        # Resize image2 to image1's spatial size when the (H, W, C) differ,
        # so the concatenation is well-formed.
        if image1.shape[1:] != image2.shape[1:]:
            image2 = totoro.utils.common_upscale(image2.movedim(-1, 1), image1.shape[2], image1.shape[1], "bilinear", "center").movedim(1, -1)
        return (torch.cat((image1, image2), dim=0),)
1679
+
1680
class EmptyImage:
    """Generate solid-color images from a packed 24-bit 0xRRGGBB integer."""

    def __init__(self, device="cpu"):
        self.device = device

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
                             "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
                             "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}),
                             }}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "generate"
    CATEGORY = "image"

    def generate(self, width, height, batch_size=1, color=0):
        # Split the packed RGB int into per-channel floats in [0, 1],
        # then build one constant plane per channel and stack on the last axis.
        values = [((color >> shift) & 0xFF) / 0xFF for shift in (16, 8, 0)]
        planes = [torch.full([batch_size, height, width, 1], v) for v in values]
        return (torch.cat(planes, dim=-1), )
1701
+
1702
class ImagePadForOutpaint:
    """Pad an image for outpainting and build the matching inpaint mask.

    The padded border is filled with 0.5 grey. The mask is 1 over the new
    border, 0 over the untouched interior, with a quadratic feather ramp of
    depth `feathering` on the edges of the original region that were padded.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                "feathering": ("INT", {"default": 40, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
            }
        }

    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "expand_image"

    CATEGORY = "image"

    def expand_image(self, image, left, top, right, bottom, feathering):
        d1, d2, d3, d4 = image.size()

        # New canvas, mid-grey where nothing has been painted yet.
        new_image = torch.ones(
            (d1, d2 + top + bottom, d3 + left + right, d4),
            dtype=torch.float32,
        ) * 0.5
        new_image[:, top:top + d2, left:left + d3, :] = image

        mask = torch.ones(
            (d2 + top + bottom, d3 + left + right),
            dtype=torch.float32,
        )

        t = torch.zeros((d2, d3), dtype=torch.float32)
        if feathering > 0 and feathering * 2 < d2 and feathering * 2 < d3:
            # Vectorized replacement for the original O(d2*d3) Python loop:
            # distance of each interior pixel to the nearest padded edge;
            # edges that were not padded count as "infinitely far" (the full
            # dimension size, matching the original's sentinel values).
            rows = torch.arange(d2, dtype=torch.float32)
            cols = torch.arange(d3, dtype=torch.float32)
            dist_top = rows if top != 0 else torch.full((d2,), float(d2))
            dist_bottom = (d2 - rows) if bottom != 0 else torch.full((d2,), float(d2))
            dist_left = cols if left != 0 else torch.full((d3,), float(d3))
            dist_right = (d3 - cols) if right != 0 else torch.full((d3,), float(d3))
            dist = torch.minimum(
                torch.minimum(dist_top, dist_bottom).unsqueeze(1),
                torch.minimum(dist_left, dist_right).unsqueeze(0),
            )
            # Quadratic ramp: 1 at the padded edge, 0 at depth `feathering`
            # and beyond (clamp reproduces the original's `continue`).
            v = ((feathering - dist) / feathering).clamp(min=0.0)
            t = v * v

        mask[top:top + d2, left:left + d3] = t

        return (new_image, mask)
1764
+
1765
+
1766
# Registry mapping node type names (as used in workflow JSON) to their
# implementing classes; extended at runtime by load_custom_node().
NODE_CLASS_MAPPINGS = {
    "KSampler": KSampler,
    "CheckpointLoaderSimple": CheckpointLoaderSimple,
    "CLIPTextEncode": CLIPTextEncode,
    "CLIPSetLastLayer": CLIPSetLastLayer,
    "VAEDecode": VAEDecode,
    "VAEEncode": VAEEncode,
    "VAEEncodeForInpaint": VAEEncodeForInpaint,
    "VAELoader": VAELoader,
    "EmptyLatentImage": EmptyLatentImage,
    "LatentUpscale": LatentUpscale,
    "LatentUpscaleBy": LatentUpscaleBy,
    "LatentFromBatch": LatentFromBatch,
    "RepeatLatentBatch": RepeatLatentBatch,
    "SaveImage": SaveImage,
    "PreviewImage": PreviewImage,
    "LoadImage": LoadImage,
    "LoadImageMask": LoadImageMask,
    "ImageScale": ImageScale,
    "ImageScaleBy": ImageScaleBy,
    "ImageInvert": ImageInvert,
    "ImageBatch": ImageBatch,
    "ImagePadForOutpaint": ImagePadForOutpaint,
    "EmptyImage": EmptyImage,
    "ConditioningAverage": ConditioningAverage ,
    "ConditioningCombine": ConditioningCombine,
    "ConditioningConcat": ConditioningConcat,
    "ConditioningSetArea": ConditioningSetArea,
    "ConditioningSetAreaPercentage": ConditioningSetAreaPercentage,
    "ConditioningSetAreaStrength": ConditioningSetAreaStrength,
    "ConditioningSetMask": ConditioningSetMask,
    "KSamplerAdvanced": KSamplerAdvanced,
    "SetLatentNoiseMask": SetLatentNoiseMask,
    "LatentComposite": LatentComposite,
    "LatentBlend": LatentBlend,
    "LatentRotate": LatentRotate,
    "LatentFlip": LatentFlip,
    "LatentCrop": LatentCrop,
    "LoraLoader": LoraLoader,
    "CLIPLoader": CLIPLoader,
    "UNETLoader": UNETLoader,
    "DualCLIPLoader": DualCLIPLoader,
    "CLIPVisionEncode": CLIPVisionEncode,
    "StyleModelApply": StyleModelApply,
    "unCLIPConditioning": unCLIPConditioning,
    "ControlNetApply": ControlNetApply,
    "ControlNetApplyAdvanced": ControlNetApplyAdvanced,
    "ControlNetLoader": ControlNetLoader,
    "DiffControlNetLoader": DiffControlNetLoader,
    "StyleModelLoader": StyleModelLoader,
    "CLIPVisionLoader": CLIPVisionLoader,
    "VAEDecodeTiled": VAEDecodeTiled,
    "VAEEncodeTiled": VAEEncodeTiled,
    "unCLIPCheckpointLoader": unCLIPCheckpointLoader,
    "GLIGENLoader": GLIGENLoader,
    "GLIGENTextBoxApply": GLIGENTextBoxApply,
    "InpaintModelConditioning": InpaintModelConditioning,

    "CheckpointLoader": CheckpointLoader,
    "DiffusersLoader": DiffusersLoader,

    "LoadLatent": LoadLatent,
    "SaveLatent": SaveLatent,

    "ConditioningZeroOut": ConditioningZeroOut,
    "ConditioningSetTimestepRange": ConditioningSetTimestepRange,
    "LoraLoaderModelOnly": LoraLoaderModelOnly,
}
1834
+
1835
# Human-readable display names for the UI, keyed by the node type names
# registered in NODE_CLASS_MAPPINGS. Keys must match those exactly.
NODE_DISPLAY_NAME_MAPPINGS = {
    # Sampling
    "KSampler": "KSampler",
    "KSamplerAdvanced": "KSampler (Advanced)",
    # Loaders
    "CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)",
    "CheckpointLoaderSimple": "Load Checkpoint",
    "VAELoader": "Load VAE",
    "LoraLoader": "Load LoRA",
    "CLIPLoader": "Load CLIP",
    "ControlNetLoader": "Load ControlNet Model",
    "DiffControlNetLoader": "Load ControlNet Model (diff)",
    "StyleModelLoader": "Load Style Model",
    "CLIPVisionLoader": "Load CLIP Vision",
    "UpscaleModelLoader": "Load Upscale Model",
    "UNETLoader": "Load Diffusion Model",
    # Conditioning
    "CLIPVisionEncode": "CLIP Vision Encode",
    "StyleModelApply": "Apply Style Model",
    "CLIPTextEncode": "CLIP Text Encode (Prompt)",
    "CLIPSetLastLayer": "CLIP Set Last Layer",
    "ConditioningCombine": "Conditioning (Combine)",
    # Bug fix: this key previously had a trailing space ("ConditioningAverage ")
    # and so never matched the "ConditioningAverage" entry in NODE_CLASS_MAPPINGS.
    "ConditioningAverage": "Conditioning (Average)",
    "ConditioningConcat": "Conditioning (Concat)",
    "ConditioningSetArea": "Conditioning (Set Area)",
    "ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)",
    "ConditioningSetMask": "Conditioning (Set Mask)",
    "ControlNetApply": "Apply ControlNet",
    "ControlNetApplyAdvanced": "Apply ControlNet (Advanced)",
    # Latent
    "VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
    "SetLatentNoiseMask": "Set Latent Noise Mask",
    "VAEDecode": "VAE Decode",
    "VAEEncode": "VAE Encode",
    "LatentRotate": "Rotate Latent",
    "LatentFlip": "Flip Latent",
    "LatentCrop": "Crop Latent",
    "EmptyLatentImage": "Empty Latent Image",
    "LatentUpscale": "Upscale Latent",
    "LatentUpscaleBy": "Upscale Latent By",
    "LatentComposite": "Latent Composite",
    "LatentBlend": "Latent Blend",
    "LatentFromBatch" : "Latent From Batch",
    "RepeatLatentBatch": "Repeat Latent Batch",
    # Image
    "SaveImage": "Save Image",
    "PreviewImage": "Preview Image",
    "LoadImage": "Load Image",
    "LoadImageMask": "Load Image (as Mask)",
    "ImageScale": "Upscale Image",
    "ImageScaleBy": "Upscale Image By",
    "ImageUpscaleWithModel": "Upscale Image (using Model)",
    "ImageInvert": "Invert Image",
    "ImagePadForOutpaint": "Pad Image for Outpainting",
    "ImageBatch": "Batch Images",
    # _for_testing
    "VAEDecodeTiled": "VAE Decode (Tiled)",
    "VAEEncodeTiled": "VAE Encode (Tiled)",
}
1894
+
1895
# Maps custom-node module names to their web asset directories; populated
# by load_custom_node() when a module defines WEB_DIRECTORY.
EXTENSION_WEB_DIRS = {}
1896
+
1897
+
1898
def get_module_name(module_path: str) -> str:
    """Return the module name for a custom-node path.

    For an existing file, the extension is stripped
    ("my_node.py" -> "my_node"); otherwise the final path component is
    returned as-is. Note the extension is only stripped when the path
    exists as a file on disk at call time.

    Args:
        module_path: Path to the module file or package directory.

    Returns:
        The derived module name.
    """
    name = os.path.basename(module_path)
    if os.path.isfile(module_path):
        name = os.path.splitext(name)[0]
    return name
1918
+
1919
+
1920
+ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool:
1921
+ module_name = os.path.basename(module_path)
1922
+ if os.path.isfile(module_path):
1923
+ sp = os.path.splitext(module_path)
1924
+ module_name = sp[0]
1925
+ try:
1926
+ logging.debug("Trying to load custom node {}".format(module_path))
1927
+ if os.path.isfile(module_path):
1928
+ module_spec = importlib.util.spec_from_file_location(module_name, module_path)
1929
+ module_dir = os.path.split(module_path)[0]
1930
+ else:
1931
+ module_spec = importlib.util.spec_from_file_location(module_name, os.path.join(module_path, "__init__.py"))
1932
+ module_dir = module_path
1933
+
1934
+ module = importlib.util.module_from_spec(module_spec)
1935
+ sys.modules[module_name] = module
1936
+ module_spec.loader.exec_module(module)
1937
+
1938
+ if hasattr(module, "WEB_DIRECTORY") and getattr(module, "WEB_DIRECTORY") is not None:
1939
+ web_dir = os.path.abspath(os.path.join(module_dir, getattr(module, "WEB_DIRECTORY")))
1940
+ if os.path.isdir(web_dir):
1941
+ EXTENSION_WEB_DIRS[module_name] = web_dir
1942
+
1943
+ if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
1944
+ for name, node_cls in module.NODE_CLASS_MAPPINGS.items():
1945
+ if name not in ignore:
1946
+ NODE_CLASS_MAPPINGS[name] = node_cls
1947
+ node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path))
1948
+ if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
1949
+ NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
1950
+ return True
1951
+ else:
1952
+ logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
1953
+ return False
1954
+ except Exception as e:
1955
+ logging.warning(traceback.format_exc())
1956
+ logging.warning(f"Cannot import {module_path} module for custom nodes: {e}")
1957
+ return False
1958
+
1959
+ def init_external_custom_nodes():
1960
+ """
1961
+ Initializes the external custom nodes.
1962
+
1963
+ This function loads custom nodes from the specified folder paths and imports them into the application.
1964
+ It measures the import times for each custom node and logs the results.
1965
+
1966
+ Returns:
1967
+ None
1968
+ """
1969
+ base_node_names = set(NODE_CLASS_MAPPINGS.keys())
1970
+ node_paths = folder_paths.get_folder_paths("custom_nodes")
1971
+ node_import_times = []
1972
+ for custom_node_path in node_paths:
1973
+ possible_modules = os.listdir(os.path.realpath(custom_node_path))
1974
+ if "__pycache__" in possible_modules:
1975
+ possible_modules.remove("__pycache__")
1976
+
1977
+ for possible_module in possible_modules:
1978
+ module_path = os.path.join(custom_node_path, possible_module)
1979
+ if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
1980
+ if module_path.endswith(".disabled"): continue
1981
+ time_before = time.perf_counter()
1982
+ success = load_custom_node(module_path, base_node_names, module_parent="custom_nodes")
1983
+ node_import_times.append((time.perf_counter() - time_before, module_path, success))
1984
+
1985
+ if len(node_import_times) > 0:
1986
+ logging.info("\nImport times for custom nodes:")
1987
+ for n in sorted(node_import_times):
1988
+ if n[2]:
1989
+ import_message = ""
1990
+ else:
1991
+ import_message = " (IMPORT FAILED)"
1992
+ logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1]))
1993
+ logging.info("")
1994
+
1995
+ def init_builtin_extra_nodes():
1996
+ """
1997
+ Initializes the built-in extra nodes in totoroUI.
1998
+
1999
+ This function loads the extra node files located in the "totoro_extras" directory and imports them into totoroUI.
2000
+ If any of the extra node files fail to import, a warning message is logged.
2001
+
2002
+ Returns:
2003
+ None
2004
+ """
2005
+ extras_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "totoro_extras")
2006
+ extras_files = [
2007
+ "nodes_latent.py",
2008
+ "nodes_hypernetwork.py",
2009
+ "nodes_upscale_model.py",
2010
+ "nodes_post_processing.py",
2011
+ "nodes_mask.py",
2012
+ "nodes_compositing.py",
2013
+ "nodes_rebatch.py",
2014
+ "nodes_model_merging.py",
2015
+ "nodes_tomesd.py",
2016
+ "nodes_clip_sdxl.py",
2017
+ "nodes_canny.py",
2018
+ "nodes_freelunch.py",
2019
+ "nodes_custom_sampler.py",
2020
+ "nodes_hypertile.py",
2021
+ "nodes_model_advanced.py",
2022
+ "nodes_model_downscale.py",
2023
+ "nodes_images.py",
2024
+ "nodes_video_model.py",
2025
+ "nodes_sag.py",
2026
+ "nodes_perpneg.py",
2027
+ "nodes_stable3d.py",
2028
+ "nodes_sdupscale.py",
2029
+ "nodes_photomaker.py",
2030
+ "nodes_cond.py",
2031
+ "nodes_morphology.py",
2032
+ "nodes_stable_cascade.py",
2033
+ "nodes_differential_diffusion.py",
2034
+ "nodes_ip2p.py",
2035
+ "nodes_model_merging_model_specific.py",
2036
+ "nodes_pag.py",
2037
+ "nodes_align_your_steps.py",
2038
+ "nodes_attention_multiply.py",
2039
+ "nodes_advanced_samplers.py",
2040
+ "nodes_webcam.py",
2041
+ "nodes_audio.py",
2042
+ "nodes_sd3.py",
2043
+ "nodes_gits.py",
2044
+ "nodes_controlnet.py",
2045
+ "nodes_hunyuan.py",
2046
+ ]
2047
+
2048
+ import_failed = []
2049
+ for node_file in extras_files:
2050
+ if not load_custom_node(os.path.join(extras_dir, node_file), module_parent="totoro_extras"):
2051
+ import_failed.append(node_file)
2052
+
2053
+ return import_failed
2054
+
2055
+
2056
+ def init_extra_nodes(init_custom_nodes=True):
2057
+ import_failed = init_builtin_extra_nodes()
2058
+
2059
+ if init_custom_nodes:
2060
+ init_external_custom_nodes()
2061
+ else:
2062
+ logging.info("Skipping loading of custom nodes")
2063
+
2064
+ if len(import_failed) > 0:
2065
+ logging.warning("WARNING: some totoro_extras/ nodes did not import correctly. This may be because they are missing some dependencies.\n")
2066
+ for node in import_failed:
2067
+ logging.warning("IMPORT FAILED: {}".format(node))
2068
+ logging.warning("\nThis issue might be caused by new missing dependencies added the last time you updated totoroUI.")
2069
+ if args.windows_standalone_build:
2070
+ logging.warning("Please run the update script: update/update_totoroui.bat")
2071
+ else:
2072
+ logging.warning("Please do a: pip install -r requirements.txt")
2073
+ logging.warning("")
content/flux/totoro/__pycache__/cli_args.cpython-311.pyc ADDED
Binary file (14.3 kB). View file
 
content/flux/totoro/__pycache__/diffusers_load.cpython-311.pyc ADDED
Binary file (2.36 kB). View file
 
content/flux/totoro/__pycache__/model_management.cpython-311.pyc ADDED
Binary file (40.8 kB). View file
 
content/flux/totoro/__pycache__/options.cpython-311.pyc ADDED
Binary file (320 Bytes). View file
 
content/flux/totoro/__pycache__/sd.cpython-311.pyc ADDED
Binary file (47.3 kB). View file
 
content/flux/totoro/checkpoint_pickle.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+
3
+ load = pickle.load
4
+
5
+ class Empty:
6
+ pass
7
+
8
+ class Unpickler(pickle.Unpickler):
9
+ def find_class(self, module, name):
10
+ #TODO: safe unpickle
11
+ if module.startswith("pytorch_lightning"):
12
+ return Empty
13
+ return super().find_class(module, name)
content/flux/totoro/cldm/cldm.py ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #taken from: https://github.com/lllyasviel/ControlNet
2
+ #and modified
3
+
4
+ import torch
5
+ import torch as th
6
+ import torch.nn as nn
7
+
8
+ from ..ldm.modules.diffusionmodules.util import (
9
+ zero_module,
10
+ timestep_embedding,
11
+ )
12
+
13
+ from ..ldm.modules.attention import SpatialTransformer
14
+ from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
15
+ from ..ldm.util import exists
16
+ from .control_types import UNION_CONTROLNET_TYPES
17
+ from collections import OrderedDict
18
+ import totoro.ops
19
+ from totoro.ldm.modules.attention import optimized_attention
20
+
21
+ class OptimizedAttention(nn.Module):
22
+ def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
23
+ super().__init__()
24
+ self.heads = nhead
25
+ self.c = c
26
+
27
+ self.in_proj = operations.Linear(c, c * 3, bias=True, dtype=dtype, device=device)
28
+ self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
29
+
30
+ def forward(self, x):
31
+ x = self.in_proj(x)
32
+ q, k, v = x.split(self.c, dim=2)
33
+ out = optimized_attention(q, k, v, self.heads)
34
+ return self.out_proj(out)
35
+
36
+ class QuickGELU(nn.Module):
37
+ def forward(self, x: torch.Tensor):
38
+ return x * torch.sigmoid(1.702 * x)
39
+
40
+ class ResBlockUnionControlnet(nn.Module):
41
+ def __init__(self, dim, nhead, dtype=None, device=None, operations=None):
42
+ super().__init__()
43
+ self.attn = OptimizedAttention(dim, nhead, dtype=dtype, device=device, operations=operations)
44
+ self.ln_1 = operations.LayerNorm(dim, dtype=dtype, device=device)
45
+ self.mlp = nn.Sequential(
46
+ OrderedDict([("c_fc", operations.Linear(dim, dim * 4, dtype=dtype, device=device)), ("gelu", QuickGELU()),
47
+ ("c_proj", operations.Linear(dim * 4, dim, dtype=dtype, device=device))]))
48
+ self.ln_2 = operations.LayerNorm(dim, dtype=dtype, device=device)
49
+
50
+ def attention(self, x: torch.Tensor):
51
+ return self.attn(x)
52
+
53
+ def forward(self, x: torch.Tensor):
54
+ x = x + self.attention(self.ln_1(x))
55
+ x = x + self.mlp(self.ln_2(x))
56
+ return x
57
+
58
+ class ControlledUnetModel(UNetModel):
59
+ #implemented in the ldm unet
60
+ pass
61
+
62
+ class ControlNet(nn.Module):
63
+ def __init__(
64
+ self,
65
+ image_size,
66
+ in_channels,
67
+ model_channels,
68
+ hint_channels,
69
+ num_res_blocks,
70
+ dropout=0,
71
+ channel_mult=(1, 2, 4, 8),
72
+ conv_resample=True,
73
+ dims=2,
74
+ num_classes=None,
75
+ use_checkpoint=False,
76
+ dtype=torch.float32,
77
+ num_heads=-1,
78
+ num_head_channels=-1,
79
+ num_heads_upsample=-1,
80
+ use_scale_shift_norm=False,
81
+ resblock_updown=False,
82
+ use_new_attention_order=False,
83
+ use_spatial_transformer=False, # custom transformer support
84
+ transformer_depth=1, # custom transformer support
85
+ context_dim=None, # custom transformer support
86
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
87
+ legacy=True,
88
+ disable_self_attentions=None,
89
+ num_attention_blocks=None,
90
+ disable_middle_self_attn=False,
91
+ use_linear_in_transformer=False,
92
+ adm_in_channels=None,
93
+ transformer_depth_middle=None,
94
+ transformer_depth_output=None,
95
+ attn_precision=None,
96
+ union_controlnet_num_control_type=None,
97
+ device=None,
98
+ operations=totoro.ops.disable_weight_init,
99
+ **kwargs,
100
+ ):
101
+ super().__init__()
102
+ assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
103
+ if use_spatial_transformer:
104
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
105
+
106
+ if context_dim is not None:
107
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
108
+ # from omegaconf.listconfig import ListConfig
109
+ # if type(context_dim) == ListConfig:
110
+ # context_dim = list(context_dim)
111
+
112
+ if num_heads_upsample == -1:
113
+ num_heads_upsample = num_heads
114
+
115
+ if num_heads == -1:
116
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
117
+
118
+ if num_head_channels == -1:
119
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
120
+
121
+ self.dims = dims
122
+ self.image_size = image_size
123
+ self.in_channels = in_channels
124
+ self.model_channels = model_channels
125
+
126
+ if isinstance(num_res_blocks, int):
127
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
128
+ else:
129
+ if len(num_res_blocks) != len(channel_mult):
130
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
131
+ "as a list/tuple (per-level) with the same length as channel_mult")
132
+ self.num_res_blocks = num_res_blocks
133
+
134
+ if disable_self_attentions is not None:
135
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
136
+ assert len(disable_self_attentions) == len(channel_mult)
137
+ if num_attention_blocks is not None:
138
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
139
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
140
+
141
+ transformer_depth = transformer_depth[:]
142
+
143
+ self.dropout = dropout
144
+ self.channel_mult = channel_mult
145
+ self.conv_resample = conv_resample
146
+ self.num_classes = num_classes
147
+ self.use_checkpoint = use_checkpoint
148
+ self.dtype = dtype
149
+ self.num_heads = num_heads
150
+ self.num_head_channels = num_head_channels
151
+ self.num_heads_upsample = num_heads_upsample
152
+ self.predict_codebook_ids = n_embed is not None
153
+
154
+ time_embed_dim = model_channels * 4
155
+ self.time_embed = nn.Sequential(
156
+ operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
157
+ nn.SiLU(),
158
+ operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
159
+ )
160
+
161
+ if self.num_classes is not None:
162
+ if isinstance(self.num_classes, int):
163
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
164
+ elif self.num_classes == "continuous":
165
+ print("setting up linear c_adm embedding layer")
166
+ self.label_emb = nn.Linear(1, time_embed_dim)
167
+ elif self.num_classes == "sequential":
168
+ assert adm_in_channels is not None
169
+ self.label_emb = nn.Sequential(
170
+ nn.Sequential(
171
+ operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
172
+ nn.SiLU(),
173
+ operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
174
+ )
175
+ )
176
+ else:
177
+ raise ValueError()
178
+
179
+ self.input_blocks = nn.ModuleList(
180
+ [
181
+ TimestepEmbedSequential(
182
+ operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
183
+ )
184
+ ]
185
+ )
186
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])
187
+
188
+ self.input_hint_block = TimestepEmbedSequential(
189
+ operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
190
+ nn.SiLU(),
191
+ operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
192
+ nn.SiLU(),
193
+ operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
194
+ nn.SiLU(),
195
+ operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
196
+ nn.SiLU(),
197
+ operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
198
+ nn.SiLU(),
199
+ operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
200
+ nn.SiLU(),
201
+ operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
202
+ nn.SiLU(),
203
+ operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
204
+ )
205
+
206
+ self._feature_size = model_channels
207
+ input_block_chans = [model_channels]
208
+ ch = model_channels
209
+ ds = 1
210
+ for level, mult in enumerate(channel_mult):
211
+ for nr in range(self.num_res_blocks[level]):
212
+ layers = [
213
+ ResBlock(
214
+ ch,
215
+ time_embed_dim,
216
+ dropout,
217
+ out_channels=mult * model_channels,
218
+ dims=dims,
219
+ use_checkpoint=use_checkpoint,
220
+ use_scale_shift_norm=use_scale_shift_norm,
221
+ dtype=self.dtype,
222
+ device=device,
223
+ operations=operations,
224
+ )
225
+ ]
226
+ ch = mult * model_channels
227
+ num_transformers = transformer_depth.pop(0)
228
+ if num_transformers > 0:
229
+ if num_head_channels == -1:
230
+ dim_head = ch // num_heads
231
+ else:
232
+ num_heads = ch // num_head_channels
233
+ dim_head = num_head_channels
234
+ if legacy:
235
+ #num_heads = 1
236
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
237
+ if exists(disable_self_attentions):
238
+ disabled_sa = disable_self_attentions[level]
239
+ else:
240
+ disabled_sa = False
241
+
242
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
243
+ layers.append(
244
+ SpatialTransformer(
245
+ ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
246
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
247
+ use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
248
+ )
249
+ )
250
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
251
+ self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
252
+ self._feature_size += ch
253
+ input_block_chans.append(ch)
254
+ if level != len(channel_mult) - 1:
255
+ out_ch = ch
256
+ self.input_blocks.append(
257
+ TimestepEmbedSequential(
258
+ ResBlock(
259
+ ch,
260
+ time_embed_dim,
261
+ dropout,
262
+ out_channels=out_ch,
263
+ dims=dims,
264
+ use_checkpoint=use_checkpoint,
265
+ use_scale_shift_norm=use_scale_shift_norm,
266
+ down=True,
267
+ dtype=self.dtype,
268
+ device=device,
269
+ operations=operations
270
+ )
271
+ if resblock_updown
272
+ else Downsample(
273
+ ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
274
+ )
275
+ )
276
+ )
277
+ ch = out_ch
278
+ input_block_chans.append(ch)
279
+ self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
280
+ ds *= 2
281
+ self._feature_size += ch
282
+
283
+ if num_head_channels == -1:
284
+ dim_head = ch // num_heads
285
+ else:
286
+ num_heads = ch // num_head_channels
287
+ dim_head = num_head_channels
288
+ if legacy:
289
+ #num_heads = 1
290
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
291
+ mid_block = [
292
+ ResBlock(
293
+ ch,
294
+ time_embed_dim,
295
+ dropout,
296
+ dims=dims,
297
+ use_checkpoint=use_checkpoint,
298
+ use_scale_shift_norm=use_scale_shift_norm,
299
+ dtype=self.dtype,
300
+ device=device,
301
+ operations=operations
302
+ )]
303
+ if transformer_depth_middle >= 0:
304
+ mid_block += [SpatialTransformer( # always uses a self-attn
305
+ ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
306
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
307
+ use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
308
+ ),
309
+ ResBlock(
310
+ ch,
311
+ time_embed_dim,
312
+ dropout,
313
+ dims=dims,
314
+ use_checkpoint=use_checkpoint,
315
+ use_scale_shift_norm=use_scale_shift_norm,
316
+ dtype=self.dtype,
317
+ device=device,
318
+ operations=operations
319
+ )]
320
+ self.middle_block = TimestepEmbedSequential(*mid_block)
321
+ self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
322
+ self._feature_size += ch
323
+
324
+ if union_controlnet_num_control_type is not None:
325
+ self.num_control_type = union_controlnet_num_control_type
326
+ num_trans_channel = 320
327
+ num_trans_head = 8
328
+ num_trans_layer = 1
329
+ num_proj_channel = 320
330
+ # task_scale_factor = num_trans_channel ** 0.5
331
+ self.task_embedding = nn.Parameter(torch.empty(self.num_control_type, num_trans_channel, dtype=self.dtype, device=device))
332
+
333
+ self.transformer_layes = nn.Sequential(*[ResBlockUnionControlnet(num_trans_channel, num_trans_head, dtype=self.dtype, device=device, operations=operations) for _ in range(num_trans_layer)])
334
+ self.spatial_ch_projs = operations.Linear(num_trans_channel, num_proj_channel, dtype=self.dtype, device=device)
335
+ #-----------------------------------------------------------------------------------------------------
336
+
337
+ control_add_embed_dim = 256
338
+ class ControlAddEmbedding(nn.Module):
339
+ def __init__(self, in_dim, out_dim, num_control_type, dtype=None, device=None, operations=None):
340
+ super().__init__()
341
+ self.num_control_type = num_control_type
342
+ self.in_dim = in_dim
343
+ self.linear_1 = operations.Linear(in_dim * num_control_type, out_dim, dtype=dtype, device=device)
344
+ self.linear_2 = operations.Linear(out_dim, out_dim, dtype=dtype, device=device)
345
+ def forward(self, control_type, dtype, device):
346
+ c_type = torch.zeros((self.num_control_type,), device=device)
347
+ c_type[control_type] = 1.0
348
+ c_type = timestep_embedding(c_type.flatten(), self.in_dim, repeat_only=False).to(dtype).reshape((-1, self.num_control_type * self.in_dim))
349
+ return self.linear_2(torch.nn.functional.silu(self.linear_1(c_type)))
350
+
351
+ self.control_add_embedding = ControlAddEmbedding(control_add_embed_dim, time_embed_dim, self.num_control_type, dtype=self.dtype, device=device, operations=operations)
352
+ else:
353
+ self.task_embedding = None
354
+ self.control_add_embedding = None
355
+
356
+ def union_controlnet_merge(self, hint, control_type, emb, context):
357
+ # Equivalent to: https://github.com/xinsir6/ControlNetPlus/tree/main
358
+ inputs = []
359
+ condition_list = []
360
+
361
+ for idx in range(min(1, len(control_type))):
362
+ controlnet_cond = self.input_hint_block(hint[idx], emb, context)
363
+ feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
364
+ if idx < len(control_type):
365
+ feat_seq += self.task_embedding[control_type[idx]].to(dtype=feat_seq.dtype, device=feat_seq.device)
366
+
367
+ inputs.append(feat_seq.unsqueeze(1))
368
+ condition_list.append(controlnet_cond)
369
+
370
+ x = torch.cat(inputs, dim=1)
371
+ x = self.transformer_layes(x)
372
+ controlnet_cond_fuser = None
373
+ for idx in range(len(control_type)):
374
+ alpha = self.spatial_ch_projs(x[:, idx])
375
+ alpha = alpha.unsqueeze(-1).unsqueeze(-1)
376
+ o = condition_list[idx] + alpha
377
+ if controlnet_cond_fuser is None:
378
+ controlnet_cond_fuser = o
379
+ else:
380
+ controlnet_cond_fuser += o
381
+ return controlnet_cond_fuser
382
+
383
+ def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
384
+ return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
385
+
386
+ def forward(self, x, hint, timesteps, context, y=None, **kwargs):
387
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
388
+ emb = self.time_embed(t_emb)
389
+
390
+ guided_hint = None
391
+ if self.control_add_embedding is not None: #Union Controlnet
392
+ control_type = kwargs.get("control_type", [])
393
+
394
+ if any([c >= self.num_control_type for c in control_type]):
395
+ max_type = max(control_type)
396
+ max_type_name = {
397
+ v: k for k, v in UNION_CONTROLNET_TYPES.items()
398
+ }[max_type]
399
+ raise ValueError(
400
+ f"Control type {max_type_name}({max_type}) is out of range for the number of control types" +
401
+ f"({self.num_control_type}) supported.\n" +
402
+ "Please consider using the ProMax ControlNet Union model.\n" +
403
+ "https://huggingface.co/xinsir/controlnet-union-sdxl-1.0/tree/main"
404
+ )
405
+
406
+ emb += self.control_add_embedding(control_type, emb.dtype, emb.device)
407
+ if len(control_type) > 0:
408
+ if len(hint.shape) < 5:
409
+ hint = hint.unsqueeze(dim=0)
410
+ guided_hint = self.union_controlnet_merge(hint, control_type, emb, context)
411
+
412
+ if guided_hint is None:
413
+ guided_hint = self.input_hint_block(hint, emb, context)
414
+
415
+ out_output = []
416
+ out_middle = []
417
+
418
+ hs = []
419
+ if self.num_classes is not None:
420
+ assert y.shape[0] == x.shape[0]
421
+ emb = emb + self.label_emb(y)
422
+
423
+ h = x
424
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
425
+ if guided_hint is not None:
426
+ h = module(h, emb, context)
427
+ h += guided_hint
428
+ guided_hint = None
429
+ else:
430
+ h = module(h, emb, context)
431
+ out_output.append(zero_conv(h, emb, context))
432
+
433
+ h = self.middle_block(h, emb, context)
434
+ out_middle.append(self.middle_block_out(h, emb, context))
435
+
436
+ return {"middle": out_middle, "output": out_output}
437
+
content/flux/totoro/cldm/control_types.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ UNION_CONTROLNET_TYPES = {
2
+ "openpose": 0,
3
+ "depth": 1,
4
+ "hed/pidi/scribble/ted": 2,
5
+ "canny/lineart/anime_lineart/mlsd": 3,
6
+ "normal": 4,
7
+ "segment": 5,
8
+ "tile": 6,
9
+ "repaint": 7,
10
+ }
content/flux/totoro/cldm/mmdit.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Dict, Optional
3
+ import totoro.ldm.modules.diffusionmodules.mmdit
4
+
5
+ class ControlNet(totoro.ldm.modules.diffusionmodules.mmdit.MMDiT):
6
+ def __init__(
7
+ self,
8
+ num_blocks = None,
9
+ dtype = None,
10
+ device = None,
11
+ operations = None,
12
+ **kwargs,
13
+ ):
14
+ super().__init__(dtype=dtype, device=device, operations=operations, final_layer=False, num_blocks=num_blocks, **kwargs)
15
+ # controlnet_blocks
16
+ self.controlnet_blocks = torch.nn.ModuleList([])
17
+ for _ in range(len(self.joint_blocks)):
18
+ self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
19
+
20
+ self.pos_embed_input = totoro.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
21
+ None,
22
+ self.patch_size,
23
+ self.in_channels,
24
+ self.hidden_size,
25
+ bias=True,
26
+ strict_img_size=False,
27
+ dtype=dtype,
28
+ device=device,
29
+ operations=operations
30
+ )
31
+
32
+ def forward(
33
+ self,
34
+ x: torch.Tensor,
35
+ timesteps: torch.Tensor,
36
+ y: Optional[torch.Tensor] = None,
37
+ context: Optional[torch.Tensor] = None,
38
+ hint = None,
39
+ ) -> torch.Tensor:
40
+
41
+ #weird sd3 controlnet specific stuff
42
+ y = torch.zeros_like(y)
43
+
44
+ if self.context_processor is not None:
45
+ context = self.context_processor(context)
46
+
47
+ hw = x.shape[-2:]
48
+ x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
49
+ x += self.pos_embed_input(hint)
50
+
51
+ c = self.t_embedder(timesteps, dtype=x.dtype)
52
+ if y is not None and self.y_embedder is not None:
53
+ y = self.y_embedder(y)
54
+ c = c + y
55
+
56
+ if context is not None:
57
+ context = self.context_embedder(context)
58
+
59
+ output = []
60
+
61
+ blocks = len(self.joint_blocks)
62
+ for i in range(blocks):
63
+ context, x = self.joint_blocks[i](
64
+ context,
65
+ x,
66
+ c=c,
67
+ use_checkpoint=self.use_checkpoint,
68
+ )
69
+
70
+ out = self.controlnet_blocks[i](x)
71
+ count = self.depth // blocks
72
+ if i == blocks - 1:
73
+ count -= 1
74
+ for j in range(count):
75
+ output.append(out)
76
+
77
+ return {"output": output}
content/flux/totoro/cli_args.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import enum
3
+ import os
4
+ from typing import Optional
5
+ import totoro.options
6
+
7
+
8
+ class EnumAction(argparse.Action):
9
+ """
10
+ Argparse action for handling Enums
11
+ """
12
+ def __init__(self, **kwargs):
13
+ # Pop off the type value
14
+ enum_type = kwargs.pop("type", None)
15
+
16
+ # Ensure an Enum subclass is provided
17
+ if enum_type is None:
18
+ raise ValueError("type must be assigned an Enum when using EnumAction")
19
+ if not issubclass(enum_type, enum.Enum):
20
+ raise TypeError("type must be an Enum when using EnumAction")
21
+
22
+ # Generate choices from the Enum
23
+ choices = tuple(e.value for e in enum_type)
24
+ kwargs.setdefault("choices", choices)
25
+ kwargs.setdefault("metavar", f"[{','.join(list(choices))}]")
26
+
27
+ super(EnumAction, self).__init__(**kwargs)
28
+
29
+ self._enum = enum_type
30
+
31
+ def __call__(self, parser, namespace, values, option_string=None):
32
+ # Convert value back into an Enum
33
+ value = self._enum(values)
34
+ setattr(namespace, self.dest, value)
35
+
36
+
37
+ parser = argparse.ArgumentParser()
38
+
39
+ parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
40
+ parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
41
+ parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
42
+ parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
43
+ parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
44
+ parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
45
+
46
+ parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
47
+ parser.add_argument("--output-directory", type=str, default=None, help="Set the totoroUI output directory.")
48
+ parser.add_argument("--temp-directory", type=str, default=None, help="Set the totoroUI temp directory (default is in the totoroUI directory).")
49
+ parser.add_argument("--input-directory", type=str, default=None, help="Set the totoroUI input directory.")
50
+ parser.add_argument("--auto-launch", action="store_true", help="Automatically launch totoroUI in the default browser.")
51
+ parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
52
+ parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
53
+ cm_group = parser.add_mutually_exclusive_group()
54
+ cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
55
+ cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
56
+
57
+
58
+ fp_group = parser.add_mutually_exclusive_group()
59
+ fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
60
+ fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
61
+
62
+ fpunet_group = parser.add_mutually_exclusive_group()
63
+ fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
64
+ fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
65
+ fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
66
+ fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
67
+
68
+ fpvae_group = parser.add_mutually_exclusive_group()
69
+ fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
70
+ fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
71
+ fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
72
+
73
+ parser.add_argument("--cpu-vae", action="store_true", help="Run the VAE on the CPU.")
74
+
75
+ fpte_group = parser.add_mutually_exclusive_group()
76
+ fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
77
+ fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
78
+ fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
79
+ fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
80
+
81
+ parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
82
+
83
+ parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
84
+
85
+ parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
86
+
87
class LatentPreviewMethod(enum.Enum):
    """Preview rendering strategies selectable via --preview-method."""
    NoPreviews = "none"
    Auto = "auto"
    Latent2RGB = "latent2rgb"
    TAESD = "taesd"
92
+
93
+ parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
94
+
95
+ attn_group = parser.add_mutually_exclusive_group()
96
+ attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
97
+ attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
98
+ attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
99
+
100
+ parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
101
+
102
+ upcast = parser.add_mutually_exclusive_group()
103
+ upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.")
104
+ upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
105
+
106
+
107
+ vram_group = parser.add_mutually_exclusive_group()
108
+ vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
109
+ vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
110
+ vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
111
+ vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
112
+ vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
113
+ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
114
+
115
+ parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
116
+
117
+ parser.add_argument("--disable-smart-memory", action="store_true", help="Force totoroUI to agressively offload to regular ram instead of keeping models in vram when it can.")
118
+ parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
119
+
120
+ parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
121
+ parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
122
+ parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
123
+
124
+ parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
125
+ parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
126
+
127
+ parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
128
+
129
+ parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
130
+
131
+ # The default built-in provider hosted under web/
132
+ DEFAULT_VERSION_STRING = "totoroanonymous/totoroUI@latest"
133
+
134
+ parser.add_argument(
135
+ "--front-end-version",
136
+ type=str,
137
+ default=DEFAULT_VERSION_STRING,
138
+ help="""
139
+ Specifies the version of the frontend to be used. This command needs internet connectivity to query and
140
+ download available frontend implementations from GitHub releases.
141
+
142
+ The version string should be in the format of:
143
+ [repoOwner]/[repoName]@[version]
144
+ where version is one of: "latest" or a valid version number (e.g. "1.0.0")
145
+ """,
146
+ )
147
+
148
def is_valid_directory(path: Optional[str]) -> Optional[str]:
    """argparse type callback: accept None or an existing directory path.

    Raises argparse.ArgumentTypeError when the path exists as an argument
    but is not a directory on disk.
    """
    if path is not None and not os.path.isdir(path):
        raise argparse.ArgumentTypeError(f"{path} is not a valid directory.")
    return path
156
+
157
+ parser.add_argument(
158
+ "--front-end-root",
159
+ type=is_valid_directory,
160
+ default=None,
161
+ help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
162
+ )
163
+
164
+ if totoro.options.args_parsing:
165
+ args = parser.parse_args()
166
+ else:
167
+ args = parser.parse_args([])
168
+
169
+ if args.windows_standalone_build:
170
+ args.auto_launch = True
171
+
172
+ if args.disable_auto_launch:
173
+ args.auto_launch = False
174
+
175
+ import logging
176
+ logging_level = logging.INFO
177
+ if args.verbose:
178
+ logging_level = logging.DEBUG
179
+
180
+ logging.basicConfig(format="%(message)s", level=logging_level)
content/flux/totoro/clip_config_bigg.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "CLIPTextModel"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "bos_token_id": 0,
7
+ "dropout": 0.0,
8
+ "eos_token_id": 49407,
9
+ "hidden_act": "gelu",
10
+ "hidden_size": 1280,
11
+ "initializer_factor": 1.0,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 5120,
14
+ "layer_norm_eps": 1e-05,
15
+ "max_position_embeddings": 77,
16
+ "model_type": "clip_text_model",
17
+ "num_attention_heads": 20,
18
+ "num_hidden_layers": 32,
19
+ "pad_token_id": 1,
20
+ "projection_dim": 1280,
21
+ "torch_dtype": "float32",
22
+ "vocab_size": 49408
23
+ }
content/flux/totoro/clip_model.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from totoro.ldm.modules.attention import optimized_attention_for_device
3
+ import totoro.ops
4
+
5
class CLIPAttention(torch.nn.Module):
    """Multi-head self-attention used by the CLIP encoder layers.

    The attention kernel itself is injected per-call (``optimized_attention``)
    so the module works with whichever implementation is selected at runtime.
    """

    def __init__(self, embed_dim, heads, dtype, device, operations):
        super().__init__()
        self.heads = heads

        def make_proj():
            # All four projections map embed_dim -> embed_dim with bias.
            return operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)

        self.q_proj = make_proj()
        self.k_proj = make_proj()
        self.v_proj = make_proj()
        self.out_proj = make_proj()

    def forward(self, x, mask=None, optimized_attention=None):
        query = self.q_proj(x)
        key = self.k_proj(x)
        value = self.v_proj(x)
        attended = optimized_attention(query, key, value, self.heads, mask)
        return self.out_proj(attended)
23
+
24
# Supported MLP activations; "quick_gelu" is the sigmoid-based approximation
# used by the original CLIP checkpoints.
ACTIVATIONS = {
    "quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
    "gelu": torch.nn.functional.gelu,
}


class CLIPMLP(torch.nn.Module):
    """Two-layer transformer feed-forward block: expand, activate, project back."""

    def __init__(self, embed_dim, intermediate_size, activation, dtype, device, operations):
        super().__init__()
        self.fc1 = operations.Linear(embed_dim, intermediate_size, bias=True, dtype=dtype, device=device)
        self.activation = ACTIVATIONS[activation]
        self.fc2 = operations.Linear(intermediate_size, embed_dim, bias=True, dtype=dtype, device=device)

    def forward(self, x):
        return self.fc2(self.activation(self.fc1(x)))
40
+
41
class CLIPLayer(torch.nn.Module):
    # One pre-norm transformer encoder layer: LayerNorm -> self-attention and
    # LayerNorm -> MLP, each wrapped in a residual connection.
    def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
        super().__init__()
        self.layer_norm1 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
        self.self_attn = CLIPAttention(embed_dim, heads, dtype, device, operations)
        self.layer_norm2 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
        self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device, operations)

    def forward(self, x, mask=None, optimized_attention=None):
        # NOTE: `+=` adds the residual in place, mutating the tensor passed in
        # by the caller (CLIPEncoder re-binds x each iteration, so that is safe
        # there; other callers should be aware of the mutation).
        x += self.self_attn(self.layer_norm1(x), mask, optimized_attention)
        x += self.mlp(self.layer_norm2(x))
        return x
53
+
54
+
55
class CLIPEncoder(torch.nn.Module):
    """Stack of CLIPLayer blocks; can also return one intermediate hidden state."""

    def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
            for _ in range(num_layers)
        )

    def forward(self, x, mask=None, intermediate_output=None):
        # Pick the attention kernel once for this device / mask configuration.
        attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)

        if intermediate_output is not None and intermediate_output < 0:
            # Support negative indexing (e.g. -2 == penultimate layer output).
            intermediate_output = len(self.layers) + intermediate_output

        intermediate = None
        for index, layer in enumerate(self.layers):
            x = layer(x, mask, attention)
            if index == intermediate_output:
                intermediate = x.clone()
        return x, intermediate
73
+
74
class CLIPEmbeddings(torch.nn.Module):
    """Token embeddings plus learned absolute position embeddings."""

    def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None, operations=None):
        super().__init__()
        self.token_embedding = operations.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
        self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)

    def forward(self, input_tokens, dtype=torch.float32):
        # `out_dtype` / cast_to are totoro.ops extensions that cast on the fly.
        tokens = self.token_embedding(input_tokens, out_dtype=dtype)
        positions = totoro.ops.cast_to(self.position_embedding.weight, dtype=dtype, device=input_tokens.device)
        return tokens + positions
82
+
83
+
84
class CLIPTextModel_(torch.nn.Module):
    # Core CLIP text transformer: embeddings -> encoder -> final LayerNorm,
    # with causal (and optional padding) attention masking.
    def __init__(self, config_dict, dtype, device, operations):
        num_layers = config_dict["num_hidden_layers"]
        embed_dim = config_dict["hidden_size"]
        heads = config_dict["num_attention_heads"]
        intermediate_size = config_dict["intermediate_size"]
        intermediate_activation = config_dict["hidden_act"]
        # Plain int attribute; safe to set before super().__init__().
        self.eos_token_id = config_dict["eos_token_id"]

        super().__init__()
        self.embeddings = CLIPEmbeddings(embed_dim, dtype=dtype, device=device, operations=operations)
        self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
        self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)

    def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
        x = self.embeddings(input_tokens, dtype=dtype)
        mask = None
        if attention_mask is not None:
            # Expand the padding mask to (batch, 1, seq, seq) and convert
            # masked positions into additive -inf attention biases.
            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
            mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))

        # Causal mask: position i may only attend to positions <= i.
        causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
        if mask is not None:
            mask += causal_mask
        else:
            mask = causal_mask

        x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output)
        x = self.final_layer_norm(x)
        if i is not None and final_layer_norm_intermediate:
            i = self.final_layer_norm(i)

        # Pooled output = hidden state at each sequence's first EOS token
        # (round() tolerates float token inputs).
        pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
        return x, i, pooled_output
118
+
119
class CLIPTextModel(torch.nn.Module):
    """CLIP text transformer plus the final text projection head."""

    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        self.num_layers = config_dict["num_hidden_layers"]
        self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
        embed_dim = config_dict["hidden_size"]
        self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
        # Initialize the projection to identity. Done under no_grad: an
        # in-place copy_() on a leaf parameter that requires grad raises a
        # RuntimeError when autograd is recording.
        with torch.no_grad():
            self.text_projection.weight.copy_(torch.eye(embed_dim))
        self.dtype = dtype

    def get_input_embeddings(self):
        """Return the token embedding module (transformers-compatible API)."""
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, embeddings):
        """Replace the token embedding module (transformers-compatible API)."""
        self.text_model.embeddings.token_embedding = embeddings

    def forward(self, *args, **kwargs):
        # Returns (hidden_states, intermediate, projected_pooled, raw_pooled).
        x = self.text_model(*args, **kwargs)
        out = self.text_projection(x[2])
        return (x[0], x[1], out, x[2])
139
+
140
+
141
class CLIPVisionEmbeddings(torch.nn.Module):
    """ViT-style patch + class-token embeddings with learned positions."""

    def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
        super().__init__()
        self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))

        self.patch_embedding = operations.Conv2d(
            in_channels=num_channels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
            dtype=dtype,
            device=device,
        )

        # One position per patch plus one for the class token.
        num_positions = (image_size // patch_size) ** 2 + 1
        self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)

    def forward(self, pixel_values):
        patches = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
        cls_token = totoro.ops.cast_to_input(self.class_embedding, patches).expand(pixel_values.shape[0], 1, -1)
        positions = totoro.ops.cast_to_input(self.position_embedding.weight, patches)
        return torch.cat([cls_token, patches], dim=1) + positions
163
+
164
+
165
class CLIPVision(torch.nn.Module):
    # CLIP vision transformer: patch embeddings -> pre-LN -> encoder -> post-LN.
    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        num_layers = config_dict["num_hidden_layers"]
        embed_dim = config_dict["hidden_size"]
        heads = config_dict["num_attention_heads"]
        intermediate_size = config_dict["intermediate_size"]
        intermediate_activation = config_dict["hidden_act"]

        self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=dtype, device=device, operations=operations)
        # "pre_layrnorm" (sic) matches the transformers CLIPVisionModel
        # checkpoint key name -- do not "fix" the spelling.
        # NOTE(review): unlike the other submodules, no dtype/device is passed
        # to these two LayerNorms -- confirm this is intentional upstream.
        self.pre_layrnorm = operations.LayerNorm(embed_dim)
        self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
        self.post_layernorm = operations.LayerNorm(embed_dim)

    def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
        x = self.embeddings(pixel_values)
        x = self.pre_layrnorm(x)
        #TODO: attention_mask?
        x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
        # Pooled output: post-LayerNorm of the class token.
        pooled_output = self.post_layernorm(x[:, 0, :])
        return x, i, pooled_output
186
+
187
class CLIPVisionModelProjection(torch.nn.Module):
    """Vision transformer followed by a linear projection into CLIP space."""

    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        self.vision_model = CLIPVision(config_dict, dtype, device, operations)
        # NOTE(review): no dtype/device forwarded here, unlike the text head --
        # presumably relying on the ops wrapper's runtime casting; confirm.
        self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)

    def forward(self, *args, **kwargs):
        hidden, intermediate, pooled = self.vision_model(*args, **kwargs)
        return (hidden, intermediate, self.visual_projection(pooled))
content/flux/totoro/clip_vision.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
2
+ import os
3
+ import torch
4
+ import json
5
+ import logging
6
+
7
+ import totoro.ops
8
+ import totoro.model_patcher
9
+ import totoro.model_management
10
+ import totoro.utils
11
+ import totoro.clip_model
12
+
13
class Output:
    """Attribute bag that also allows dict-style access: out["k"] is out.k."""

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, item):
        setattr(self, key, item)
18
+
19
def clip_preprocess(image, size=224):
    """Convert an NHWC image batch in [0, 1] into normalized NCHW CLIP input.

    Resizes so the short side equals `size` (bicubic, antialiased),
    center-crops to size x size, quantizes to 8-bit levels, then applies the
    standard CLIP mean/std normalization.
    """
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=image.device, dtype=image.dtype)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=image.device, dtype=image.dtype)

    image = image.movedim(-1, 1)  # NHWC -> NCHW
    height, width = image.shape[2], image.shape[3]
    if height != size or width != size:
        scale = size / min(height, width)
        image = torch.nn.functional.interpolate(
            image,
            size=(round(scale * height), round(scale * width)),
            mode="bicubic",
            antialias=True,
        )
    top = (image.shape[2] - size) // 2
    left = (image.shape[3] - size) // 2
    image = image[:, :, top:top + size, left:left + size]
    # Snap to 8-bit levels to match the reference preprocessing pipeline.
    image = torch.clip(255. * image, 0, 255).round() / 255.0
    return (image - mean.view([3, 1, 1])) / std.view([3, 1, 1])
31
+
32
class ClipVisionModel():
    # Wrapper owning a CLIPVisionModelProjection together with its device
    # placement and a ModelPatcher for totoro's model management.
    def __init__(self, json_config):
        with open(json_config) as f:
            config = json.load(f)

        self.image_size = config.get("image_size", 224)
        self.load_device = totoro.model_management.text_encoder_device()
        offload_device = totoro.model_management.text_encoder_offload_device()
        self.dtype = totoro.model_management.text_encoder_dtype(self.load_device)
        # Built on the offload device; moved to load_device on demand.
        self.model = totoro.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, totoro.ops.manual_cast)
        self.model.eval()

        self.patcher = totoro.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)

    def load_sd(self, sd):
        # Returns (missing_keys, unexpected_keys), like load_state_dict.
        return self.model.load_state_dict(sd, strict=False)

    def get_sd(self):
        return self.model.state_dict()

    def encode_image(self, image):
        # image: NHWC float tensor, presumably in [0, 1] (see clip_preprocess).
        totoro.model_management.load_model_gpu(self.patcher)
        pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float()
        # intermediate_output=-2 requests the penultimate layer's hidden state.
        out = self.model(pixel_values=pixel_values, intermediate_output=-2)

        outputs = Output()
        outputs["last_hidden_state"] = out[0].to(totoro.model_management.intermediate_device())
        outputs["image_embeds"] = out[2].to(totoro.model_management.intermediate_device())
        outputs["penultimate_hidden_states"] = out[1].to(totoro.model_management.intermediate_device())
        return outputs
62
+
63
def convert_to_transformers(sd, prefix):
    """Rewrite an OpenAI-CLIP style visual state dict into transformers naming.

    If the dict is already transformers-style, only the prefix is stripped.
    The input dict is modified in place and the converted dict returned.
    """
    if "{}transformer.resblocks.0.attn.in_proj_weight".format(prefix) in sd:
        renames = {
            "{}class_embedding".format(prefix): "vision_model.embeddings.class_embedding",
            "{}conv1.weight".format(prefix): "vision_model.embeddings.patch_embedding.weight",
            "{}positional_embedding".format(prefix): "vision_model.embeddings.position_embedding.weight",
            "{}ln_post.bias".format(prefix): "vision_model.post_layernorm.bias",
            "{}ln_post.weight".format(prefix): "vision_model.post_layernorm.weight",
            "{}ln_pre.bias".format(prefix): "vision_model.pre_layrnorm.bias",
            "{}ln_pre.weight".format(prefix): "vision_model.pre_layrnorm.weight",
        }
        for old_key, new_key in renames.items():
            if old_key in sd:
                sd[new_key] = sd.pop(old_key)

        proj_key = "{}proj".format(prefix)
        if proj_key in sd:
            # Stored transposed relative to the transformers Linear layout.
            sd['visual_projection.weight'] = sd.pop(proj_key).transpose(0, 1)

        sd = transformers_convert(sd, prefix, "vision_model.", 48)
    else:
        sd = state_dict_prefix_replace(sd, {prefix: ""})
    return sd
88
+
89
def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
    # Detect which CLIP vision variant the state dict belongs to (by layer
    # count / position-embedding size), build the matching model and load the
    # weights. Returns None when no known variant matches.
    if convert_keys:
        sd = convert_to_transformers(sd, prefix)
    if "vision_model.encoder.layers.47.layer_norm1.weight" in sd:
        # 48 layers -> ViT-bigG
        json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
    elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
        # 32 layers -> ViT-H
        json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
    elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
        # 24 layers -> ViT-L; the 336px variant has 577 positions ((336/14)^2 + 1).
        if sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
            json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
        else:
            json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
    else:
        return None

    clip = ClipVisionModel(json_config)
    m, u = clip.load_sd(sd)
    if len(m) > 0:
        logging.warning("missing clip vision: {}".format(m))
    u = set(u)
    keys = list(sd.keys())
    for k in keys:
        if k not in u:
            # Drop tensors that were consumed by the model to free memory early.
            t = sd.pop(k)
            del t
    return clip
115
+
116
def load(ckpt_path):
    """Load a CLIP vision checkpoint from disk, converting OpenAI-CLIP keys if present."""
    sd = load_torch_file(ckpt_path)
    if "visual.transformer.resblocks.0.attn.in_proj_weight" in sd:
        # OpenAI CLIP layout: strip the "visual." prefix and rename keys.
        return load_clipvision_from_sd(sd, prefix="visual.", convert_keys=True)
    return load_clipvision_from_sd(sd)
content/flux/totoro/clip_vision_config_g.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attention_dropout": 0.0,
3
+ "dropout": 0.0,
4
+ "hidden_act": "gelu",
5
+ "hidden_size": 1664,
6
+ "image_size": 224,
7
+ "initializer_factor": 1.0,
8
+ "initializer_range": 0.02,
9
+ "intermediate_size": 8192,
10
+ "layer_norm_eps": 1e-05,
11
+ "model_type": "clip_vision_model",
12
+ "num_attention_heads": 16,
13
+ "num_channels": 3,
14
+ "num_hidden_layers": 48,
15
+ "patch_size": 14,
16
+ "projection_dim": 1280,
17
+ "torch_dtype": "float32"
18
+ }
content/flux/totoro/clip_vision_config_h.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attention_dropout": 0.0,
3
+ "dropout": 0.0,
4
+ "hidden_act": "gelu",
5
+ "hidden_size": 1280,
6
+ "image_size": 224,
7
+ "initializer_factor": 1.0,
8
+ "initializer_range": 0.02,
9
+ "intermediate_size": 5120,
10
+ "layer_norm_eps": 1e-05,
11
+ "model_type": "clip_vision_model",
12
+ "num_attention_heads": 16,
13
+ "num_channels": 3,
14
+ "num_hidden_layers": 32,
15
+ "patch_size": 14,
16
+ "projection_dim": 1024,
17
+ "torch_dtype": "float32"
18
+ }
content/flux/totoro/clip_vision_config_vitl.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attention_dropout": 0.0,
3
+ "dropout": 0.0,
4
+ "hidden_act": "quick_gelu",
5
+ "hidden_size": 1024,
6
+ "image_size": 224,
7
+ "initializer_factor": 1.0,
8
+ "initializer_range": 0.02,
9
+ "intermediate_size": 4096,
10
+ "layer_norm_eps": 1e-05,
11
+ "model_type": "clip_vision_model",
12
+ "num_attention_heads": 16,
13
+ "num_channels": 3,
14
+ "num_hidden_layers": 24,
15
+ "patch_size": 14,
16
+ "projection_dim": 768,
17
+ "torch_dtype": "float32"
18
+ }
content/flux/totoro/clip_vision_config_vitl_336.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attention_dropout": 0.0,
3
+ "dropout": 0.0,
4
+ "hidden_act": "quick_gelu",
5
+ "hidden_size": 1024,
6
+ "image_size": 336,
7
+ "initializer_factor": 1.0,
8
+ "initializer_range": 0.02,
9
+ "intermediate_size": 4096,
10
+ "layer_norm_eps": 1e-5,
11
+ "model_type": "clip_vision_model",
12
+ "num_attention_heads": 16,
13
+ "num_channels": 3,
14
+ "num_hidden_layers": 24,
15
+ "patch_size": 14,
16
+ "projection_dim": 768,
17
+ "torch_dtype": "float32"
18
+ }
content/flux/totoro/conds.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+ import totoro.utils
4
+
5
+
6
def lcm(a, b):
    """Least common multiple of two ints (0 if either operand is 0)."""
    # math.lcm landed in Python 3.9; keep the manual fallback for 3.8.
    if hasattr(math, "lcm"):
        return math.lcm(a, b)
    return abs(a * b) // math.gcd(a, b)
8
+
9
class CONDRegular:
    """Base conditioning wrapper around a single tensor."""

    def __init__(self, cond):
        self.cond = cond

    def _copy_with(self, cond):
        # Preserve the concrete subclass when cloning with a new tensor.
        return self.__class__(cond)

    def process_cond(self, batch_size, device, **kwargs):
        repeated = totoro.utils.repeat_to_batch_size(self.cond, batch_size)
        return self._copy_with(repeated.to(device))

    def can_concat(self, other):
        # Regular conds batch together only when shapes match exactly.
        return self.cond.shape == other.cond.shape

    def concat(self, others):
        tensors = [self.cond] + [o.cond for o in others]
        return torch.cat(tensors)
29
+
30
class CONDNoiseShape(CONDRegular):
    """Conditioning tensor shaped like the latent; cropped to the active area."""

    def process_cond(self, batch_size, device, area, **kwargs):
        data = self.cond
        if area is not None:
            # area = (size_0..size_{d-1}, offset_0..offset_{d-1}) over the
            # spatial dims (which start at tensor dim 2).
            dims = len(area) // 2
            for axis in range(dims):
                data = data.narrow(axis + 2, area[axis + dims], area[axis])

        return self._copy_with(totoro.utils.repeat_to_batch_size(data, batch_size).to(device))
39
+
40
+
41
class CONDCrossAttn(CONDRegular):
    """Cross-attention conditioning whose token counts may differ per entry.

    Entries with different token lengths can still be batched by repeating
    each one up to the least common multiple of the lengths; repetition along
    the token axis does not change cross-attention results.
    """

    def can_concat(self, other):
        s1, s2 = self.cond.shape, other.cond.shape
        if s1 != s2:
            if s1[0] != s2[0] or s1[2] != s2[2]:  # batch/feature mismatch: should not happen
                return False

            # Cap the repeat factor -- excessive padding hurts performance.
            padded_len = lcm(s1[1], s2[1])
            if padded_len // min(s1[1], s2[1]) > 4:
                return False
        return True

    def concat(self, others):
        conds = [self.cond] + [o.cond for o in others]
        target_len = conds[0].shape[1]
        for c in conds[1:]:
            target_len = lcm(target_len, c.shape[1])

        out = []
        for c in conds:
            if c.shape[1] < target_len:
                # Repeat-padding along the token axis is result-neutral.
                c = c.repeat(1, target_len // c.shape[1], 1)
            out.append(c)
        return torch.cat(out)
69
+
70
class CONDConstant(CONDRegular):
    """A conditioning value passed through untouched (need not be a tensor)."""

    def __init__(self, cond):
        self.cond = cond

    def process_cond(self, batch_size, device, **kwargs):
        # Constants are independent of batch size and device.
        return self._copy_with(self.cond)

    def can_concat(self, other):
        # Only identical constants can share a batch.
        if self.cond != other.cond:
            return False
        return True

    def concat(self, others):
        return self.cond
content/flux/totoro/controlnet.py ADDED
@@ -0,0 +1,610 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+ import os
4
+ import logging
5
+ import totoro.utils
6
+ import totoro.model_management
7
+ import totoro.model_detection
8
+ import totoro.model_patcher
9
+ import totoro.ops
10
+ import totoro.latent_formats
11
+
12
+ import totoro.cldm.cldm
13
+ import totoro.t2i_adapter.adapter
14
+ import totoro.ldm.cascade.controlnet
15
+ import totoro.cldm.mmdit
16
+
17
+
18
def broadcast_image_to(tensor, target_batch_size, batched_number):
    """Tile a hint tensor along the batch dim to match the sampled batch.

    A single-image tensor broadcasts implicitly and is returned unchanged.
    Otherwise the tensor is trimmed/repeated up to
    ``target_batch_size // batched_number`` entries and then duplicated once
    per batched pass.
    """
    if tensor.shape[0] == 1:
        return tensor

    per_batch = target_batch_size // batched_number
    tensor = tensor[:per_batch]

    if per_batch > tensor.shape[0]:
        # Repeat whole copies, then append a partial slice for the remainder.
        full, remainder = divmod(per_batch, tensor.shape[0])
        tensor = torch.cat([tensor] * full + [tensor[:remainder]], dim=0)

    if tensor.shape[0] == target_batch_size:
        return tensor
    return torch.cat([tensor] * batched_number, dim=0)
35
+
36
class ControlBase:
    # Shared state and plumbing for all controlnet-style hint models.
    # Instances chain via `previous_controlnet`; most operations cascade
    # down the chain. (Further methods of this class follow below.)
    def __init__(self, device=None):
        self.cond_hint_original = None
        self.cond_hint = None          # lazily prepared per-run hint tensor
        self.strength = 1.0
        self.timestep_percent_range = (0.0, 1.0)
        self.latent_format = None
        self.vae = None
        self.global_average_pooling = False
        self.timestep_range = None     # resolved from the percent range in pre_run
        self.compression_ratio = 8     # latent-to-pixel downscale factor
        self.upscale_algorithm = 'nearest-exact'
        self.extra_args = {}

        if device is None:
            device = totoro.model_management.get_torch_device()
        self.device = device
        self.previous_controlnet = None

    def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
        # Store the raw hint image; it is resized/moved lazily at sampling time.
        self.cond_hint_original = cond_hint
        self.strength = strength
        self.timestep_percent_range = timestep_percent_range
        if self.latent_format is not None:
            self.vae = vae
        return self

    def pre_run(self, model, percent_to_timestep_function):
        # Convert the percent range into concrete timesteps for this model.
        self.timestep_range = (percent_to_timestep_function(self.timestep_percent_range[0]), percent_to_timestep_function(self.timestep_percent_range[1]))
        if self.previous_controlnet is not None:
            self.previous_controlnet.pre_run(model, percent_to_timestep_function)

    def set_previous_controlnet(self, controlnet):
        self.previous_controlnet = controlnet
        return self

    def cleanup(self):
        # Drop per-run state, cascading down the chain.
        if self.previous_controlnet is not None:
            self.previous_controlnet.cleanup()
        if self.cond_hint is not None:
            del self.cond_hint
            self.cond_hint = None
        self.timestep_range = None

    def get_models(self):
        out = []
        if self.previous_controlnet is not None:
            out += self.previous_controlnet.get_models()
        return out

    def copy_to(self, c):
        # Copy user-configurable settings onto a fresh instance `c`.
        c.cond_hint_original = self.cond_hint_original
        c.strength = self.strength
        c.timestep_percent_range = self.timestep_percent_range
        c.global_average_pooling = self.global_average_pooling
        c.compression_ratio = self.compression_ratio
        c.upscale_algorithm = self.upscale_algorithm
        c.latent_format = self.latent_format
        c.extra_args = self.extra_args.copy()
        c.vae = self.vae

    def inference_memory_requirements(self, dtype):
        if self.previous_controlnet is not None:
            return self.previous_controlnet.inference_memory_requirements(dtype)
        return 0
100
+ return 0
101
+
102
+ def control_merge(self, control, control_prev, output_dtype):
103
+ out = {'input':[], 'middle':[], 'output': []}
104
+
105
+ for key in control:
106
+ control_output = control[key]
107
+ applied_to = set()
108
+ for i in range(len(control_output)):
109
+ x = control_output[i]
110
+ if x is not None:
111
+ if self.global_average_pooling:
112
+ x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
113
+
114
+ if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
115
+ applied_to.add(x)
116
+ x *= self.strength
117
+
118
+ if x.dtype != output_dtype:
119
+ x = x.to(output_dtype)
120
+
121
+ out[key].append(x)
122
+
123
+ if control_prev is not None:
124
+ for x in ['input', 'middle', 'output']:
125
+ o = out[x]
126
+ for i in range(len(control_prev[x])):
127
+ prev_val = control_prev[x][i]
128
+ if i >= len(o):
129
+ o.append(prev_val)
130
+ elif prev_val is not None:
131
+ if o[i] is None:
132
+ o[i] = prev_val
133
+ else:
134
+ if o[i].shape[0] < prev_val.shape[0]:
135
+ o[i] = prev_val + o[i]
136
+ else:
137
+ o[i] = prev_val + o[i] #TODO: change back to inplace add if shared tensors stop being an issue
138
+ return out
139
+
140
+ def set_extra_arg(self, argument, value=None):
141
+ self.extra_args[argument] = value
142
+
143
+
144
class ControlNet(ControlBase):
    """Classic controlnet: a copy of the UNet encoder driven by an image hint."""

    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
        super().__init__(device)
        self.control_model = control_model
        self.load_device = load_device
        if control_model is not None:
            self.control_model_wrapped = totoro.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=totoro.model_management.unet_offload_device())

        self.compression_ratio = compression_ratio
        self.global_average_pooling = global_average_pooling
        self.model_sampling_current = None
        self.manual_cast_dtype = manual_cast_dtype
        self.latent_format = latent_format

    def get_control(self, x_noisy, t, cond, batched_number):
        """Run the control model for one step and merge with any chained result."""
        control_prev = None
        if self.previous_controlnet is not None:
            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)

        # Outside the configured timestep window this net contributes nothing.
        if self.timestep_range is not None:
            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
                return control_prev

        dtype = self.control_model.dtype
        if self.manual_cast_dtype is not None:
            dtype = self.manual_cast_dtype

        output_dtype = x_noisy.dtype

        # (Re)build the cached hint whenever the latent resolution changed.
        expected_h = x_noisy.shape[2] * self.compression_ratio
        expected_w = x_noisy.shape[3] * self.compression_ratio
        if self.cond_hint is None or self.cond_hint.shape[2] != expected_h or self.cond_hint.shape[3] != expected_w:
            if self.cond_hint is not None:
                del self.cond_hint
            self.cond_hint = None
            compression_ratio = self.compression_ratio
            if self.vae is not None:
                compression_ratio *= self.vae.downscale_ratio
            self.cond_hint = totoro.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
            if self.vae is not None:
                # Remember what was loaded so VAE encoding doesn't permanently
                # evict the currently-used models from the GPU.
                loaded_models = totoro.model_management.loaded_models(only_currently_used=True)
                self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
                totoro.model_management.load_models_gpu(loaded_models)
            if self.latent_format is not None:
                self.cond_hint = self.latent_format.process_in(self.cond_hint)
            self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
        if x_noisy.shape[0] != self.cond_hint.shape[0]:
            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)

        context = cond.get('crossattn_controlnet', cond['c_crossattn'])
        y = cond.get('y', None)
        if y is not None:
            y = y.to(dtype)
        timestep = self.model_sampling_current.timestep(t)
        x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)

        control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, **self.extra_args)
        return self.control_merge(control, control_prev, output_dtype)

    def copy(self):
        """Shallow copy sharing the (wrapped) control model."""
        c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
        c.control_model = self.control_model
        c.control_model_wrapped = self.control_model_wrapped
        self.copy_to(c)
        return c

    def get_models(self):
        models = super().get_models()
        models.append(self.control_model_wrapped)
        return models

    def pre_run(self, model, percent_to_timestep_function):
        super().pre_run(model, percent_to_timestep_function)
        self.model_sampling_current = model.model_sampling

    def cleanup(self):
        self.model_sampling_current = None
        super().cleanup()
222
+
223
class ControlLoraOps:
    """Layer ops whose effective weight is base weight + up@down (LoRA delta).

    The `weight`/`bias` tensors are attached externally (borrowed from the
    base model); `up`/`down` hold the low-rank delta when present.
    """

    class Linear(torch.nn.Module, totoro.ops.CastWeightBiasOp):
        def __init__(self, in_features: int, out_features: int, bias: bool = True,
                     device=None, dtype=None) -> None:
            super().__init__()
            self.in_features = in_features
            self.out_features = out_features
            self.weight = None
            self.up = None
            self.down = None
            self.bias = None

        def forward(self, input):
            weight, bias = totoro.ops.cast_bias_weight(self, input)
            if self.up is None:
                return torch.nn.functional.linear(input, weight, bias)
            # Materialize the low-rank delta and fold it into the base weight.
            delta = torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))
            return torch.nn.functional.linear(input, weight + delta.reshape(self.weight.shape).type(input.dtype), bias)

    class Conv2d(torch.nn.Module, totoro.ops.CastWeightBiasOp):
        def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=True,
            padding_mode='zeros',
            device=None,
            dtype=None
        ):
            super().__init__()
            self.in_channels = in_channels
            self.out_channels = out_channels
            self.kernel_size = kernel_size
            self.stride = stride
            self.padding = padding
            self.dilation = dilation
            self.transposed = False
            self.output_padding = 0
            self.groups = groups
            self.padding_mode = padding_mode

            self.weight = None
            self.bias = None
            self.up = None
            self.down = None

        def forward(self, input):
            weight, bias = totoro.ops.cast_bias_weight(self, input)
            if self.up is None:
                return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
            delta = torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))
            return torch.nn.functional.conv2d(input, weight + delta.reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
282
+
283
+
284
class ControlLora(ControlNet):
    """Controlnet stored as LoRA deltas over the base UNet weights.

    Instead of shipping a full encoder copy, the checkpoint only holds
    low-rank (up/down) matrices; at `pre_run` the base model's weights are
    borrowed and the deltas are attached on top via `ControlLoraOps` layers.
    """
    def __init__(self, control_weights, global_average_pooling=False, device=None):
        # Deliberately skip ControlNet.__init__: there is no control model yet.
        ControlBase.__init__(self, device)
        self.control_weights = control_weights
        self.global_average_pooling = global_average_pooling

    def pre_run(self, model, percent_to_timestep_function):
        """Build the control model from the base model's config and weights."""
        super().pre_run(model, percent_to_timestep_function)
        controlnet_config = model.model_config.unet_config.copy()
        controlnet_config.pop("out_channels")
        # Hint channel count comes from the stored input hint block weight.
        controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
        self.manual_cast_dtype = model.manual_cast_dtype
        dtype = model.get_dtype()
        if self.manual_cast_dtype is None:
            class control_lora_ops(ControlLoraOps, totoro.ops.disable_weight_init):
                pass
        else:
            class control_lora_ops(ControlLoraOps, totoro.ops.manual_cast):
                pass
            dtype = self.manual_cast_dtype

        controlnet_config["operations"] = control_lora_ops
        controlnet_config["dtype"] = dtype
        self.control_model = totoro.cldm.cldm.ControlNet(**controlnet_config)
        self.control_model.to(totoro.model_management.get_torch_device())
        diffusion_model = model.diffusion_model
        sd = diffusion_model.state_dict()

        # Borrow the base model's weights; keys the controlnet doesn't have
        # (e.g. decoder-side params) are simply skipped.
        for k in sd:
            weight = sd[k]
            try:
                totoro.utils.set_attr_param(self.control_model, k, weight)
            except Exception:
                # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
                # are not swallowed while copying weights.
                pass

        # Attach the controlnet-specific weights (hint block, zero convs, lora).
        for k in self.control_weights:
            if k not in {"lora_controlnet"}:
                totoro.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(totoro.model_management.get_torch_device()))

    def copy(self):
        c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
        self.copy_to(c)
        return c

    def cleanup(self):
        # The control model is rebuilt every run; free it entirely.
        del self.control_model
        self.control_model = None
        super().cleanup()

    def get_models(self):
        # No wrapped ModelPatcher: weights live inside the base model / here.
        out = ControlBase.get_models(self)
        return out

    def inference_memory_requirements(self, dtype):
        return totoro.utils.calculate_parameters(self.control_weights) * totoro.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
340
+
341
def load_controlnet_mmdit(sd):
    """Build an SD3 (MMDiT) controlnet from a diffusers-style state dict."""
    new_sd = totoro.model_detection.convert_diffusers_mmdit(sd, "")
    model_config = totoro.model_detection.model_config_from_unet(new_sd, "", True)
    num_blocks = totoro.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
    # Re-add controlnet-specific keys the converter doesn't know about.
    for k in sd:
        new_sd[k] = sd[k]

    supported_inference_dtypes = model_config.supported_inference_dtypes

    controlnet_config = model_config.unet_config
    unet_dtype = totoro.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
    load_device = totoro.model_management.get_torch_device()
    manual_cast_dtype = totoro.model_management.unet_manual_cast(unet_dtype, load_device)
    if manual_cast_dtype is not None:
        operations = totoro.ops.manual_cast
    else:
        operations = totoro.ops.disable_weight_init

    control_model = totoro.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **controlnet_config)
    missing, unexpected = control_model.load_state_dict(new_sd, strict=False)

    if len(missing) > 0:
        logging.warning("missing controlnet keys: {}".format(missing))

    if len(unexpected) > 0:
        logging.debug("unexpected controlnet keys: {}".format(unexpected))

    latent_format = totoro.latent_formats.SD3()
    latent_format.shift_factor = 0  # SD3 controlnet weirdness
    # compression_ratio=1: the hint is encoded through the VAE, not downscaled.
    return ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
372
+
373
+
374
def load_controlnet(ckpt_path, model=None):
    """Load a controlnet / control-lora / t2i-adapter checkpoint from disk.

    Detects the checkpoint flavor from its keys (control-lora, diffusers
    controlnet, SD3 diffusers controlnet, .pth-style "control_model." prefix,
    plain controlnet, or t2i adapter) and returns the matching wrapper object.
    `model` is only used for "difference" controlnets, whose weights are
    stored as deltas against the base model. Returns None when nothing is
    recognized.
    """
    controlnet_data = totoro.utils.load_torch_file(ckpt_path, safe_load=True)
    if "lora_controlnet" in controlnet_data:
        return ControlLora(controlnet_data)

    controlnet_config = None
    supported_inference_dtypes = None

    if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
        # Build a diffusers->internal key mapping, then rename everything.
        controlnet_config = totoro.model_detection.unet_config_from_diffusers_unet(controlnet_data)
        diffusers_keys = totoro.utils.unet_to_diffusers(controlnet_config)
        diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
        diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"

        # Map the zero-conv down blocks until keys run out.
        count = 0
        loop = True
        while loop:
            suffix = [".weight", ".bias"]
            for s in suffix:
                k_in = "controlnet_down_blocks.{}{}".format(count, s)
                k_out = "zero_convs.{}.0{}".format(count, s)
                if k_in not in controlnet_data:
                    loop = False
                    break
                diffusers_keys[k_in] = k_out
            count += 1

        # Map the hint embedding: conv_in, blocks.N, then conv_out last.
        count = 0
        loop = True
        while loop:
            suffix = [".weight", ".bias"]
            for s in suffix:
                if count == 0:
                    k_in = "controlnet_cond_embedding.conv_in{}".format(s)
                else:
                    k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
                k_out = "input_hint_block.{}{}".format(count * 2, s)
                if k_in not in controlnet_data:
                    # Final layer is named conv_out in the diffusers layout.
                    k_in = "controlnet_cond_embedding.conv_out{}".format(s)
                    loop = False
                diffusers_keys[k_in] = k_out
            count += 1

        new_sd = {}
        for k in diffusers_keys:
            if k in controlnet_data:
                new_sd[diffusers_keys[k]] = controlnet_data.pop(k)

        if "control_add_embedding.linear_1.bias" in controlnet_data: #Union Controlnet
            controlnet_config["union_controlnet_num_control_type"] = controlnet_data["task_embedding"].shape[0]
            for k in list(controlnet_data.keys()):
                new_k = k.replace('.attn.in_proj_', '.attn.in_proj.')
                new_sd[new_k] = controlnet_data.pop(k)

        leftover_keys = controlnet_data.keys()
        if len(leftover_keys) > 0:
            logging.warning("leftover keys: {}".format(leftover_keys))
        controlnet_data = new_sd
    elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
        return load_controlnet_mmdit(controlnet_data)

    # Distinguish .pth checkpoints (keys prefixed "control_model.") from plain ones.
    pth_key = 'control_model.zero_convs.0.0.weight'
    pth = False
    key = 'zero_convs.0.0.weight'
    if pth_key in controlnet_data:
        pth = True
        key = pth_key
        prefix = "control_model."
    elif key in controlnet_data:
        prefix = ""
    else:
        # Not a controlnet at all; try the t2i-adapter loader as a last resort.
        net = load_t2i_adapter(controlnet_data)
        if net is None:
            logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
        return net

    if controlnet_config is None:
        model_config = totoro.model_detection.model_config_from_unet(controlnet_data, prefix, True)
        supported_inference_dtypes = model_config.supported_inference_dtypes
        controlnet_config = model_config.unet_config

    load_device = totoro.model_management.get_torch_device()
    if supported_inference_dtypes is None:
        unet_dtype = totoro.model_management.unet_dtype()
    else:
        unet_dtype = totoro.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)

    manual_cast_dtype = totoro.model_management.unet_manual_cast(unet_dtype, load_device)
    if manual_cast_dtype is not None:
        controlnet_config["operations"] = totoro.ops.manual_cast
    controlnet_config["dtype"] = unet_dtype
    controlnet_config.pop("out_channels")
    controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
    control_model = totoro.cldm.cldm.ControlNet(**controlnet_config)

    if pth:
        if 'difference' in controlnet_data:
            # "Difference" controlnets store deltas; add the base model weights back in.
            if model is not None:
                totoro.model_management.load_models_gpu([model])
                model_sd = model.model_state_dict()
                for x in controlnet_data:
                    c_m = "control_model."
                    if x.startswith(c_m):
                        sd_key = "diffusion_model.{}".format(x[len(c_m):])
                        if sd_key in model_sd:
                            cd = controlnet_data[x]
                            cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
            else:
                logging.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")

        # Dummy wrapper so the "control_model." key prefix resolves during loading.
        class WeightsLoader(torch.nn.Module):
            pass
        w = WeightsLoader()
        w.control_model = control_model
        missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
    else:
        missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)

    if len(missing) > 0:
        logging.warning("missing controlnet keys: {}".format(missing))

    if len(unexpected) > 0:
        logging.debug("unexpected controlnet keys: {}".format(unexpected))

    global_average_pooling = False
    filename = os.path.splitext(ckpt_path)[0]
    if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
        global_average_pooling = True

    control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
    return control
505
+
506
class T2IAdapter(ControlBase):
    """Lightweight hint model: features are computed once, then replayed each step."""

    def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
        super().__init__(device)
        self.t2i_model = t2i_model
        self.channels_in = channels_in
        self.control_input = None
        self.compression_ratio = compression_ratio
        self.upscale_algorithm = upscale_algorithm

    def scale_image_to(self, width, height):
        """Round both dimensions up to the model's pixel-unshuffle multiple."""
        unshuffle = self.t2i_model.unshuffle_amount
        width = math.ceil(width / unshuffle) * unshuffle
        height = math.ceil(height / unshuffle) * unshuffle
        return width, height

    def get_control(self, x_noisy, t, cond, batched_number):
        control_prev = None
        if self.previous_controlnet is not None:
            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)

        # Outside the active timestep window this adapter contributes nothing.
        if self.timestep_range is not None:
            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
                return control_prev

        if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
            # Resolution changed: rebuild the hint and invalidate cached features.
            if self.cond_hint is not None:
                del self.cond_hint
            self.control_input = None
            self.cond_hint = None
            width, height = self.scale_image_to(x_noisy.shape[3] * self.compression_ratio, x_noisy.shape[2] * self.compression_ratio)
            self.cond_hint = totoro.utils.common_upscale(self.cond_hint_original, width, height, self.upscale_algorithm, "center").float().to(self.device)
            if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
                # Single-channel adapters: collapse RGB hints to grayscale.
                self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
        if x_noisy.shape[0] != self.cond_hint.shape[0]:
            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
        if self.control_input is None:
            # Adapter features don't depend on the timestep: run once, reuse.
            self.t2i_model.to(x_noisy.dtype)
            self.t2i_model.to(self.device)
            self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
            self.t2i_model.cpu()

        # Hand out clones so control_merge's in-place scaling can't corrupt the cache.
        control_input = {}
        for key, feats in self.control_input.items():
            control_input[key] = [None if f is None else f.clone() for f in feats]

        return self.control_merge(control_input, control_prev, x_noisy.dtype)

    def copy(self):
        c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm)
        self.copy_to(c)
        return c
560
+
561
def load_t2i_adapter(t2i_data):
    """Instantiate the matching T2I adapter / cascade controlnet from a state dict.

    Returns a T2IAdapter wrapper, or None when the dict is not recognized.
    """
    compression_ratio = 8
    upscale_algorithm = 'nearest-exact'

    if 'adapter' in t2i_data:
        t2i_data = t2i_data['adapter']
    if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: #diffusers format
        # Flatten the diffusers "adapter.body.i.resnets.j" layout to "body.N".
        prefix_replace = {}
        for i in range(4):
            for j in range(2):
                prefix_replace["adapter.body.{}.resnets.{}.".format(i, j)] = "body.{}.".format(i * 2 + j)
            prefix_replace["adapter.body.{}.".format(i)] = "body.{}.".format(i * 2)
        prefix_replace["adapter."] = ""
        t2i_data = totoro.utils.state_dict_prefix_replace(t2i_data, prefix_replace)
    keys = t2i_data.keys()

    if "body.0.in_conv.weight" in keys:
        cin = t2i_data['body.0.in_conv.weight'].shape[1]
        model_ad = totoro.t2i_adapter.adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
    elif 'conv_in.weight' in keys:
        cin = t2i_data['conv_in.weight'].shape[1]
        channel = t2i_data['conv_in.weight'].shape[0]
        ksize = t2i_data['body.0.block2.weight'].shape[2]
        # Conv downsampling layers are only present in some checkpoints.
        use_conv = len([k for k in keys if k.endswith("down_opt.op.weight")]) > 0
        # SDXL adapters take pixel-unshuffled inputs (64 * 4 or 192 * 4 channels).
        xl = cin in (256, 768)
        model_ad = totoro.t2i_adapter.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
    elif "backbone.0.0.weight" in keys:
        model_ad = totoro.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.0.weight'].shape[1], proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
        compression_ratio = 32
        upscale_algorithm = 'bilinear'
    elif "backbone.10.blocks.0.weight" in keys:
        model_ad = totoro.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.weight'].shape[1], bottleneck_mode="large", proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
        compression_ratio = 1
        upscale_algorithm = 'nearest-exact'
    else:
        return None

    missing, unexpected = model_ad.load_state_dict(t2i_data)
    if len(missing) > 0:
        logging.warning("t2i missing {}".format(missing))

    if len(unexpected) > 0:
        logging.debug("t2i unexpected {}".format(unexpected))

    return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm)
content/flux/totoro/diffusers_convert.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import torch
3
+ import logging
4
+
5
+ # conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
6
+
7
+ # =================#
8
+ # UNet Conversion #
9
+ # =================#
10
+
11
+ unet_conversion_map = [
12
+ # (stable-diffusion, HF Diffusers)
13
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
14
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
15
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
16
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
17
+ ("input_blocks.0.0.weight", "conv_in.weight"),
18
+ ("input_blocks.0.0.bias", "conv_in.bias"),
19
+ ("out.0.weight", "conv_norm_out.weight"),
20
+ ("out.0.bias", "conv_norm_out.bias"),
21
+ ("out.2.weight", "conv_out.weight"),
22
+ ("out.2.bias", "conv_out.bias"),
23
+ ]
24
+
25
+ unet_conversion_map_resnet = [
26
+ # (stable-diffusion, HF Diffusers)
27
+ ("in_layers.0", "norm1"),
28
+ ("in_layers.2", "conv1"),
29
+ ("out_layers.0", "norm2"),
30
+ ("out_layers.3", "conv2"),
31
+ ("emb_layers.1", "time_emb_proj"),
32
+ ("skip_connection", "conv_shortcut"),
33
+ ]
34
+
35
+ unet_conversion_map_layer = []
36
+ # hardcoded number of downblocks and resnets/attentions...
37
+ # would need smarter logic for other networks.
38
+ for i in range(4):
39
+ # loop over downblocks/upblocks
40
+
41
+ for j in range(2):
42
+ # loop over resnets/attentions for downblocks
43
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
44
+ sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
45
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
46
+
47
+ if i < 3:
48
+ # no attention layers in down_blocks.3
49
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
50
+ sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
51
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
52
+
53
+ for j in range(3):
54
+ # loop over resnets/attentions for upblocks
55
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
56
+ sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
57
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
58
+
59
+ if i > 0:
60
+ # no attention layers in up_blocks.0
61
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
62
+ sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
63
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
64
+
65
+ if i < 3:
66
+ # no downsample in down_blocks.3
67
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
68
+ sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
69
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
70
+
71
+ # no upsample in up_blocks.3
72
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
73
+ sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
74
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
75
+
76
+ hf_mid_atn_prefix = "mid_block.attentions.0."
77
+ sd_mid_atn_prefix = "middle_block.1."
78
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
79
+
80
+ for j in range(2):
81
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
82
+ sd_mid_res_prefix = f"middle_block.{2 * j}."
83
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
84
+
85
+
86
def convert_unet_state_dict(unet_state_dict):
    """Map HF diffusers UNet keys onto original stable-diffusion checkpoint keys.

    NOTE: brittle by design — the three mapping passes below must run in
    exactly this order for the substring replacements to compose correctly.
    """
    # Start from an identity mapping, then layer the conversion tables on top.
    mapping = {key: key for key in unet_state_dict.keys()}
    for sd_name, hf_name in unet_conversion_map:
        mapping[hf_name] = sd_name
    for hf_key, sd_key in mapping.items():
        if "resnets" in hf_key:
            for sd_part, hf_part in unet_conversion_map_resnet:
                sd_key = sd_key.replace(hf_part, sd_part)
            mapping[hf_key] = sd_key
    for hf_key, sd_key in mapping.items():
        for sd_part, hf_part in unet_conversion_map_layer:
            sd_key = sd_key.replace(hf_part, sd_part)
        mapping[hf_key] = sd_key
    return {sd_key: unet_state_dict[hf_key] for hf_key, sd_key in mapping.items()}
104
+
105
+
106
+ # ================#
107
+ # VAE Conversion #
108
+ # ================#
109
+
110
+ vae_conversion_map = [
111
+ # (stable-diffusion, HF Diffusers)
112
+ ("nin_shortcut", "conv_shortcut"),
113
+ ("norm_out", "conv_norm_out"),
114
+ ("mid.attn_1.", "mid_block.attentions.0."),
115
+ ]
116
+
117
+ for i in range(4):
118
+ # down_blocks have two resnets
119
+ for j in range(2):
120
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
121
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
122
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
123
+
124
+ if i < 3:
125
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
126
+ sd_downsample_prefix = f"down.{i}.downsample."
127
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
128
+
129
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
130
+ sd_upsample_prefix = f"up.{3 - i}.upsample."
131
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
132
+
133
+ # up_blocks have three resnets
134
+ # also, up blocks in hf are numbered in reverse from sd
135
+ for j in range(3):
136
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
137
+ sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
138
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
139
+
140
+ # this part accounts for mid blocks in both the encoder and the decoder
141
+ for i in range(2):
142
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
143
+ sd_mid_res_prefix = f"mid.block_{i + 1}."
144
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
145
+
146
+ vae_conversion_map_attn = [
147
+ # (stable-diffusion, HF Diffusers)
148
+ ("norm.", "group_norm."),
149
+ ("q.", "query."),
150
+ ("k.", "key."),
151
+ ("v.", "value."),
152
+ ("q.", "to_q."),
153
+ ("k.", "to_k."),
154
+ ("v.", "to_v."),
155
+ ("proj_out.", "to_out.0."),
156
+ ("proj_out.", "proj_attn."),
157
+ ]
158
+
159
+
160
def reshape_weight_for_sd(w):
    """Append two singleton spatial dims: HF linear weight -> SD 1x1 conv weight."""
    return w.reshape(w.shape + (1, 1))
163
+
164
+
165
def convert_vae_state_dict(vae_state_dict):
    """Map HF diffusers VAE keys onto original stable-diffusion checkpoint keys."""
    mapping = {key: key for key in vae_state_dict.keys()}
    for hf_key, sd_key in mapping.items():
        for sd_part, hf_part in vae_conversion_map:
            sd_key = sd_key.replace(hf_part, sd_part)
        mapping[hf_key] = sd_key
    for hf_key, sd_key in mapping.items():
        if "attentions" in hf_key:
            for sd_part, hf_part in vae_conversion_map_attn:
                sd_key = sd_key.replace(hf_part, sd_part)
            mapping[hf_key] = sd_key
    new_state_dict = {sd_key: vae_state_dict[hf_key] for hf_key, sd_key in mapping.items()}
    # Mid-block attention projections are linear in HF but 1x1 convs in SD.
    weights_to_convert = ["q", "k", "v", "proj_out"]
    for key, tensor in new_state_dict.items():
        for weight_name in weights_to_convert:
            if f"mid.attn_1.{weight_name}.weight" in key:
                logging.debug(f"Reshaping {key} for SD format")
                new_state_dict[key] = reshape_weight_for_sd(tensor)
    return new_state_dict
184
+
185
+
186
+ # =========================#
187
+ # Text Encoder Conversion #
188
+ # =========================#
189
+
190
+
191
+ textenc_conversion_lst = [
192
+ # (stable-diffusion, HF Diffusers)
193
+ ("resblocks.", "text_model.encoder.layers."),
194
+ ("ln_1", "layer_norm1"),
195
+ ("ln_2", "layer_norm2"),
196
+ (".c_fc.", ".fc1."),
197
+ (".c_proj.", ".fc2."),
198
+ (".attn", ".self_attn"),
199
+ ("ln_final.", "transformer.text_model.final_layer_norm."),
200
+ ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
201
+ ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
202
+ ]
203
+ protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
204
+ textenc_pattern = re.compile("|".join(protected.keys()))
205
+
206
+ # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
207
+ code2idx = {"q": 0, "k": 1, "v": 2}
208
+
209
+ # This function exists because at the time of writing torch.cat can't do fp8 with cuda
210
def cat_tensors(tensors):
    """Concatenate along dim 0 without torch.cat.

    Exists because at the time of writing torch.cat can't handle fp8 on cuda;
    copying slice-by-slice into a pre-allocated buffer works for any dtype.
    """
    total_rows = sum(t.shape[0] for t in tensors)
    out_shape = [total_rows] + list(tensors[0].shape)[1:]
    out = torch.empty(out_shape, device=tensors[0].device, dtype=tensors[0].dtype)

    offset = 0
    for t in tensors:
        out[offset:offset + t.shape[0]] = t
        offset += t.shape[0]

    return out
224
+
225
def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
    """Convert an HF CLIP (v2-style) text-encoder state dict to SD format.

    Separate q/k/v projection weights and biases are captured per layer and
    fused into the single `in_proj_weight` / `in_proj_bias` tensors the SD
    checkpoint layout expects. Keys not starting with `prefix` are dropped.
    """
    new_state_dict = {}
    qkv_weights = {}
    qkv_biases = {}
    for key, tensor in text_enc_dict.items():
        if not key.startswith(prefix):
            continue
        if key.endswith((".self_attn.q_proj.weight", ".self_attn.k_proj.weight", ".self_attn.v_proj.weight")):
            stem = key[: -len(".q_proj.weight")]
            which = key[-len("q_proj.weight")]  # 'q', 'k' or 'v'
            qkv_weights.setdefault(stem, [None, None, None])[code2idx[which]] = tensor
            continue

        if key.endswith((".self_attn.q_proj.bias", ".self_attn.k_proj.bias", ".self_attn.v_proj.bias")):
            stem = key[: -len(".q_proj.bias")]
            which = key[-len("q_proj.bias")]
            qkv_biases.setdefault(stem, [None, None, None])[code2idx[which]] = tensor
            continue

        text_proj = "transformer.text_projection.weight"
        if key.endswith(text_proj):
            # SD stores the projection transposed relative to HF.
            new_state_dict[key.replace(text_proj, "text_projection")] = tensor.transpose(0, 1).contiguous()
        else:
            relabelled = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], key)
            new_state_dict[relabelled] = tensor

    for stem, parts in qkv_weights.items():
        if None in parts:
            raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
        relabelled = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], stem)
        new_state_dict[relabelled + ".in_proj_weight"] = cat_tensors(parts)

    for stem, parts in qkv_biases.items():
        if None in parts:
            raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
        relabelled = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], stem)
        new_state_dict[relabelled + ".in_proj_bias"] = cat_tensors(parts)

    return new_state_dict
276
+
277
+
278
def convert_text_enc_state_dict(text_enc_dict):
    """v1 CLIP text encoders already use SD naming; return the dict unchanged."""
    return text_enc_dict
280
+
281
+
content/flux/totoro/diffusers_load.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import totoro.sd
4
+
5
def first_file(path, filenames):
    """Return the first existing file among *filenames* joined onto *path*.

    Candidates are checked in the order given; ``None`` is returned when
    none of them exists on disk.
    """
    candidates = (os.path.join(path, name) for name in filenames)
    return next((candidate for candidate in candidates if os.path.exists(candidate)), None)
11
+
12
def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_directory=None):
    """Load a diffusers-layout checkpoint directory into (unet, clip, vae).

    Args:
        model_path: Root directory of the diffusers checkpoint, expected to
            contain ``unet``, ``vae``, ``text_encoder`` and optionally
            ``text_encoder_2`` subfolders.
        output_vae: When False, skip loading the VAE (returned as ``None``).
        output_clip: When False, skip loading the text encoder(s)
            (returned as ``None``).
        embedding_directory: Optional directory of textual-inversion
            embeddings, forwarded to the CLIP loader.

    Returns:
        Tuple ``(unet, clip, vae)``; ``clip`` / ``vae`` are ``None`` when
        not requested.
    """
    # Fix: the original relied on `totoro.sd` importing `totoro.utils`
    # transitively; import it explicitly so this function does not break if
    # that internal detail of totoro.sd ever changes.
    import totoro.utils

    diffusion_model_names = ["diffusion_pytorch_model.fp16.safetensors", "diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.fp16.bin", "diffusion_pytorch_model.bin"]
    unet_path = first_file(os.path.join(model_path, "unet"), diffusion_model_names)
    vae_path = first_file(os.path.join(model_path, "vae"), diffusion_model_names)

    text_encoder_model_names = ["model.fp16.safetensors", "model.safetensors", "pytorch_model.fp16.bin", "pytorch_model.bin"]
    text_encoder1_path = first_file(os.path.join(model_path, "text_encoder"), text_encoder_model_names)
    text_encoder2_path = first_file(os.path.join(model_path, "text_encoder_2"), text_encoder_model_names)

    # NOTE(review): text_encoder1_path may be None if the folder is missing;
    # it is passed through to the loader as-is (original behavior preserved).
    text_encoder_paths = [text_encoder1_path]
    if text_encoder2_path is not None:
        text_encoder_paths.append(text_encoder2_path)

    unet = totoro.sd.load_unet(unet_path)

    clip = None
    if output_clip:
        clip = totoro.sd.load_clip(text_encoder_paths, embedding_directory=embedding_directory)

    vae = None
    if output_vae:
        sd = totoro.utils.load_torch_file(vae_path)
        vae = totoro.sd.VAE(sd=sd)

    return (unet, clip, vae)
content/flux/totoro/extra_samplers/uni_pc.py ADDED
@@ -0,0 +1,875 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #code taken from: https://github.com/wl-zhao/UniPC and modified
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import math
6
+
7
+ from tqdm.auto import trange, tqdm
8
+
9
+
10
class NoiseScheduleVP:
    """Wrapper for the forward SDE (VP type) of a diffusion model.

    The forward SDE satisfies q_{t|0}(x_t | x_0) = N(alpha_t * x_0, sigma_t^2 * I).
    With lambda_t = log(alpha_t) - log(sigma_t) (the half-logSNR of DPM-Solver),
    this class exposes, for t in [0, T]:

        log_alpha_t = self.marginal_log_mean_coeff(t)
        sigma_t     = self.marginal_std(t)
        lambda_t    = self.marginal_lambda(t)
        t           = self.inverse_lambda(lambda_t)   # lambda is invertible
    """

    def __init__(
        self,
        schedule='discrete',
        betas=None,
        alphas_cumprod=None,
        continuous_beta_0=0.1,
        continuous_beta_1=20.,
    ):
        """Create a noise-schedule wrapper.

        Both discrete-time DPMs (trained on n = 0..N-1) and continuous-time
        DPMs (trained on t in [t_0, T]) are supported.

        Discrete-time ('discrete'): discrete steps are mapped to continuous
        time via t_i = (i + 1) / N, and log_alpha_t is obtained by piecewise
        linear interpolation over the given schedule. Exactly one of `betas`
        or `alphas_cumprod` must be supplied; note that the DDPM notation
        \\hat{alpha_n} relates to DPM-Solver's alpha by
        log(alpha_{t_n}) = 0.5 * log(\\hat{alpha_n}).

        Continuous-time ('linear' or 'cosine'): uses the standard DDPM /
        improved-DDPM hyperparameters (`continuous_beta_0`,
        `continuous_beta_1` for the linear schedule; cosine_s = 0.008 for
        the cosine schedule).

        Args:
            schedule: 'discrete', 'linear' or 'cosine'.
            betas: `torch.Tensor`. Beta array for a discrete-time DPM.
            alphas_cumprod: `torch.Tensor`. Cumprod alphas for a
                discrete-time DPM (alternative to `betas`).
            continuous_beta_0: `float`. Smallest beta of the linear schedule.
            continuous_beta_1: `float`. Largest beta of the linear schedule.

        Raises:
            ValueError: if `schedule` is not one of the supported names.

        Example:
            >>> ns = NoiseScheduleVP('discrete', betas=betas)
            >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
        """

        if schedule not in ['discrete', 'linear', 'cosine']:
            raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))

        self.schedule = schedule
        if schedule == 'discrete':
            if betas is not None:
                # log(alpha_t) = 0.5 * cumsum(log(1 - beta)) per the DDPM definition.
                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
            else:
                assert alphas_cumprod is not None
                log_alphas = 0.5 * torch.log(alphas_cumprod)
            self.total_N = len(log_alphas)
            self.T = 1.
            # Interpolation grid: t_i = (i + 1) / N for i = 0..N-1.
            self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
            self.log_alpha_array = log_alphas.reshape((1, -1,))
        else:
            self.total_N = 1000
            self.beta_0 = continuous_beta_0
            self.beta_1 = continuous_beta_1
            self.cosine_s = 0.008
            self.cosine_beta_max = 999.
            self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
            self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
            self.schedule = schedule
            if schedule == 'cosine':
                # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
                # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
                self.T = 0.9946
            else:
                self.T = 1.

    def marginal_log_mean_coeff(self, t):
        """
        Compute log(alpha_t) of a given continuous-time label t in [0, T].
        """
        if self.schedule == 'discrete':
            return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
        elif self.schedule == 'linear':
            return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        elif self.schedule == 'cosine':
            log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
            log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
            return log_alpha_t

    def marginal_alpha(self, t):
        """
        Compute alpha_t of a given continuous-time label t in [0, T].
        """
        return torch.exp(self.marginal_log_mean_coeff(t))

    def marginal_std(self, t):
        """
        Compute sigma_t of a given continuous-time label t in [0, T].
        """
        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """
        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
        """
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
        return log_mean_coeff - log_std

    def inverse_lambda(self, lamb):
        """
        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
        """
        if self.schedule == 'linear':
            tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            Delta = self.beta_0**2 + tmp
            # Closed-form inversion of the quadratic in marginal_log_mean_coeff.
            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
        elif self.schedule == 'discrete':
            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
            # Invert the piecewise-linear interpolation by flipping the arrays
            # so that log_alpha is increasing along the interpolation axis.
            t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
            return t.reshape((-1,))
        else:
            log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
            t = t_fn(log_alpha)
            return t
179
+
180
+
181
def model_wrapper(
    model,
    noise_schedule,
    model_type="noise",
    model_kwargs=None,
    guidance_type="uncond",
    condition=None,
    unconditional_condition=None,
    guidance_scale=1.,
    classifier_fn=None,
    classifier_kwargs=None,
):
    """Create a wrapper function for the noise prediction model.

    DPM-Solver / UniPC solve continuous-time diffusion ODEs, so models
    trained on discrete-time labels are wrapped into a noise prediction
    model that accepts continuous time as input.

    Supported model parameterizations (`model_type`):
      1. "noise":   noise prediction model.
      2. "x_start": data prediction model (predicts x_0 at time 0).
      3. "v":       velocity prediction model (Salimans & Ho, 2022;
                    Imagen-Video, 2022).
      4. "score":   marginal score function, related to noise prediction by
                    noise(x_t, t) = -sigma_t * score(x_t, t).

    Supported guided-sampling modes (`guidance_type`):
      1. "uncond":          model(x, t_input, **model_kwargs) -> prediction.
      2. "classifier":      classifier guidance (Dhariwal & Nichol, 2021);
                            requires `classifier_fn(x, t_input, cond,
                            **classifier_kwargs) -> logits`.
      3. "classifier-free": classifier-free guidance (Ho & Salimans, 2022);
                            model(x, t_input, cond, **model_kwargs), where
                            cond == `unconditional_condition` yields the
                            unconditional output.

    `t_input` may be discrete labels (0..999) or continuous labels
    (epsilon..T); the wrapper converts as needed.

    Args:
        model: A diffusion model in one of the formats above.
        noise_schedule: A noise schedule object, such as NoiseScheduleVP.
        model_type: "noise", "x_start", "v" or "score".
        model_kwargs: Optional dict of extra model inputs (default: {}).
        guidance_type: "uncond", "classifier" or "classifier-free".
        condition: Condition tensor for guided sampling.
        unconditional_condition: Condition tensor for the unconditional
            branch of classifier-free guidance.
        guidance_scale: Guidance scale for guided sampling.
        classifier_fn: Classifier function (classifier guidance only).
        classifier_kwargs: Optional dict of extra classifier inputs
            (default: {}).

    Returns:
        A function `model_fn(x, t_continuous) -> noise` used by the solver.
    """
    # Fix: avoid mutable default arguments (shared across calls).
    model_kwargs = {} if model_kwargs is None else model_kwargs
    classifier_kwargs = {} if classifier_kwargs is None else classifier_kwargs

    def get_model_input_time(t_continuous):
        """
        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
        For continuous-time DPMs, we just use `t_continuous`.
        """
        if noise_schedule.schedule == 'discrete':
            return (t_continuous - 1. / noise_schedule.total_N) * 1000.
        else:
            return t_continuous

    def noise_pred_fn(x, t_continuous, cond=None):
        # Broadcast a scalar time to the batch dimension.
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        t_input = get_model_input_time(t_continuous)
        output = model(x, t_input, **model_kwargs)
        if model_type == "noise":
            return output
        elif model_type == "x_start":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
        elif model_type == "v":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
        elif model_type == "score":
            sigma_t = noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return -expand_dims(sigma_t, dims) * output

    def cond_grad_fn(x, t_input):
        """
        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
        """
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
            return torch.autograd.grad(log_prob.sum(), x_in)[0]

    def model_fn(x, t_continuous):
        """
        The noise prediction model function that is used for DPM-Solver.
        """
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        if guidance_type == "uncond":
            return noise_pred_fn(x, t_continuous)
        elif guidance_type == "classifier":
            assert classifier_fn is not None
            t_input = get_model_input_time(t_continuous)
            cond_grad = cond_grad_fn(x, t_input)
            sigma_t = noise_schedule.marginal_std(t_continuous)
            noise = noise_pred_fn(x, t_continuous)
            return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
        elif guidance_type == "classifier-free":
            if guidance_scale == 1. or unconditional_condition is None:
                return noise_pred_fn(x, t_continuous, cond=condition)
            else:
                # Run conditional and unconditional branches in one batched call.
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t_continuous] * 2)
                c_in = torch.cat([unconditional_condition, condition])
                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
                return noise_uncond + guidance_scale * (noise - noise_uncond)

    # Fix: "score" is implemented by noise_pred_fn and documented above, but
    # was rejected by the original assertion.
    assert model_type in ["noise", "x_start", "v", "score"]
    assert guidance_type in ["uncond", "classifier", "classifier-free"]
    return model_fn
350
+
351
+
352
class UniPC:
    """Unified Predictor-Corrector (UniPC) solver for diffusion ODEs.

    Wraps a noise/data prediction `model_fn` (see `model_wrapper`) and a
    `NoiseScheduleVP`-style schedule, and integrates the diffusion ODE with
    a multistep predictor-corrector scheme ('bh1'/'bh2' or 'vary_coeff'
    variants).
    """

    def __init__(
        self,
        model_fn,
        noise_schedule,
        predict_x0=True,
        thresholding=False,
        max_val=1.,
        variant='bh1',
    ):
        """Construct a UniPC.

        We support both data_prediction (predict_x0=True) and
        noise_prediction parameterizations.

        Args:
            model_fn: Noise prediction function `model_fn(x, t)`.
            noise_schedule: Schedule object (e.g. NoiseScheduleVP).
            predict_x0: Solve in data-prediction mode when True.
            thresholding: Apply Imagen-style dynamic thresholding to x0.
            max_val: Lower bound for the thresholding scale `s`.
            variant: 'bh1', 'bh2' or 'vary_coeff'.
        """
        self.model = model_fn
        self.noise_schedule = noise_schedule
        self.variant = variant
        self.predict_x0 = predict_x0
        self.thresholding = thresholding
        self.max_val = max_val

    def dynamic_thresholding_fn(self, x0, t=None):
        """
        The dynamic thresholding method.

        NOTE(review): reads `self.dynamic_thresholding_ratio` and
        `self.thresholding_max_val`, which are not set in `__init__`; this
        method appears unused by the sampling path (data_prediction_fn has
        its own inline thresholding) — confirm before calling.
        """
        dims = x0.dim()
        p = self.dynamic_thresholding_ratio
        s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
        s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
        x0 = torch.clamp(x0, -s, s) / s
        return x0

    def noise_prediction_fn(self, x, t):
        """
        Return the noise prediction model.
        """
        return self.model(x, t)

    def data_prediction_fn(self, x, t):
        """
        Return the data prediction model (with thresholding).
        """
        noise = self.noise_prediction_fn(x, t)
        dims = x.dim()
        alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
        # x0 = (x - sigma_t * noise) / alpha_t  (inversion of the forward process mean).
        x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
        if self.thresholding:
            p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
            s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
            s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
            x0 = torch.clamp(x0, -s, s) / s
        return x0

    def model_fn(self, x, t):
        """
        Convert the model to the noise prediction model or the data prediction model.
        """
        if self.predict_x0:
            return self.data_prediction_fn(x, t)
        else:
            return self.noise_prediction_fn(x, t)

    def get_time_steps(self, skip_type, t_T, t_0, N, device):
        """Compute the intermediate time steps for sampling.

        skip_type is 'logSNR' (uniform in half-logSNR), 'time_uniform'
        (uniform in t) or 'time_quadratic' (uniform in sqrt(t)).
        """
        if skip_type == 'logSNR':
            lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
            lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
            logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
            return self.noise_schedule.inverse_lambda(logSNR_steps)
        elif skip_type == 'time_uniform':
            return torch.linspace(t_T, t_0, N + 1).to(device)
        elif skip_type == 'time_quadratic':
            t_order = 2
            t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
            return t
        else:
            raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))

    def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
        """
        Get the order of each step for sampling by the singlestep DPM-Solver.
        """
        if order == 3:
            K = steps // 3 + 1
            if steps % 3 == 0:
                orders = [3,] * (K - 2) + [2, 1]
            elif steps % 3 == 1:
                orders = [3,] * (K - 1) + [1]
            else:
                orders = [3,] * (K - 1) + [2]
        elif order == 2:
            if steps % 2 == 0:
                K = steps // 2
                orders = [2,] * K
            else:
                K = steps // 2 + 1
                orders = [2,] * (K - 1) + [1]
        elif order == 1:
            K = steps
            orders = [1,] * steps
        else:
            raise ValueError("'order' must be '1' or '2' or '3'.")
        if skip_type == 'logSNR':
            # To reproduce the results in DPM-Solver paper
            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
        else:
            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)]
        return timesteps_outer, orders

    def denoise_to_zero_fn(self, x, s):
        """
        Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
        """
        return self.data_prediction_fn(x, s)

    def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs):
        """Dispatch one predictor-corrector update to the configured variant."""
        if len(t.shape) == 0:
            t = t.view(-1)
        if 'bh' in self.variant:
            return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
        else:
            assert self.variant == 'vary_coeff'
            return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)

    def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
        """One UniPC update with the varying-coefficient solver.

        Returns (x_t, model_t); model_t is None if the corrector was skipped.
        """
        print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
        ns = self.noise_schedule
        assert order <= len(model_prev_list)

        # first compute rks
        t_prev_0 = t_prev_list[-1]
        lambda_prev_0 = ns.marginal_lambda(t_prev_0)
        lambda_t = ns.marginal_lambda(t)
        model_prev_0 = model_prev_list[-1]
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        log_alpha_t = ns.marginal_log_mean_coeff(t)
        alpha_t = torch.exp(log_alpha_t)

        h = lambda_t - lambda_prev_0

        rks = []
        D1s = []
        for i in range(1, order):
            t_prev_i = t_prev_list[-(i + 1)]
            model_prev_i = model_prev_list[-(i + 1)]
            lambda_prev_i = ns.marginal_lambda(t_prev_i)
            rk = (lambda_prev_i - lambda_prev_0) / h
            rks.append(rk)
            D1s.append((model_prev_i - model_prev_0) / rk)

        rks.append(1.)
        rks = torch.tensor(rks, device=x.device)

        K = len(rks)
        # build C matrix
        C = []

        col = torch.ones_like(rks)
        for k in range(1, K + 1):
            C.append(col)
            col = col * rks / (k + 1)
        C = torch.stack(C, dim=1)

        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1) # (B, K)
            C_inv_p = torch.linalg.inv(C[:-1, :-1])
            A_p = C_inv_p

        if use_corrector:
            print('using corrector')
            C_inv = torch.linalg.inv(C)
            A_c = C_inv

        # h * phi_k(h) recurrences; sign flips in data-prediction mode.
        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)
        h_phi_ks = []
        factorial_k = 1
        h_phi_k = h_phi_1
        for k in range(1, K + 2):
            h_phi_ks.append(h_phi_k)
            h_phi_k = h_phi_k / hh - 1 / factorial_k
            factorial_k *= (k + 1)

        model_t = None
        if self.predict_x0:
            x_t_ = (
                sigma_t / sigma_prev_0 * x
                - alpha_t * h_phi_1 * model_prev_0
            )
            # now predictor
            x_t = x_t_
            if len(D1s) > 0:
                # compute the residuals for predictor
                for k in range(K - 1):
                    x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
            # now corrector
            if use_corrector:
                model_t = self.model_fn(x_t, t)
                D1_t = (model_t - model_prev_0)
                x_t = x_t_
                k = 0
                for k in range(K - 1):
                    x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
                x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
        else:
            log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
            x_t_ = (
                (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
                - (sigma_t * h_phi_1) * model_prev_0
            )
            # now predictor
            x_t = x_t_
            if len(D1s) > 0:
                # compute the residuals for predictor
                for k in range(K - 1):
                    x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
            # now corrector
            if use_corrector:
                model_t = self.model_fn(x_t, t)
                D1_t = (model_t - model_prev_0)
                x_t = x_t_
                k = 0
                for k in range(K - 1):
                    x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
                x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
        return x_t, model_t

    def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
        """One UniPC update with the B(h) solver ('bh1': B(h)=h, 'bh2': B(h)=expm1(h)).

        If `x_t` is given, the predictor step is skipped and only the
        corrector is applied. Returns (x_t, model_t); model_t is None when
        the corrector was skipped.
        """
        # print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
        ns = self.noise_schedule
        assert order <= len(model_prev_list)
        dims = x.dim()

        # first compute rks
        t_prev_0 = t_prev_list[-1]
        lambda_prev_0 = ns.marginal_lambda(t_prev_0)
        lambda_t = ns.marginal_lambda(t)
        model_prev_0 = model_prev_list[-1]
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
        alpha_t = torch.exp(log_alpha_t)

        h = lambda_t - lambda_prev_0

        rks = []
        D1s = []
        for i in range(1, order):
            t_prev_i = t_prev_list[-(i + 1)]
            model_prev_i = model_prev_list[-(i + 1)]
            lambda_prev_i = ns.marginal_lambda(t_prev_i)
            rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
            rks.append(rk)
            D1s.append((model_prev_i - model_prev_0) / rk)

        rks.append(1.)
        rks = torch.tensor(rks, device=x.device)

        R = []
        b = []

        hh = -h[0] if self.predict_x0 else h[0]
        h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.variant == 'bh1':
            B_h = hh
        elif self.variant == 'bh2':
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        # Build the Vandermonde-like system R @ rhos = b of the UniPC paper.
        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= (i + 1)
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=x.device)

        # now predictor
        use_predictor = len(D1s) > 0 and x_t is None
        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1) # (B, K)
            if x_t is None:
                # for order 2, we use a simplified version
                if order == 2:
                    rhos_p = torch.tensor([0.5], device=b.device)
                else:
                    rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
        else:
            D1s = None

        if use_corrector:
            # print('using corrector')
            # for order 1, we use a simplified version
            if order == 1:
                rhos_c = torch.tensor([0.5], device=b.device)
            else:
                rhos_c = torch.linalg.solve(R, b)

        model_t = None
        if self.predict_x0:
            x_t_ = (
                expand_dims(sigma_t / sigma_prev_0, dims) * x
                - expand_dims(alpha_t * h_phi_1, dims)* model_prev_0
            )

            if x_t is None:
                if use_predictor:
                    pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
                else:
                    pred_res = 0
                x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res

            if use_corrector:
                model_t = self.model_fn(x_t, t)
                if D1s is not None:
                    corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
                else:
                    corr_res = 0
                D1_t = (model_t - model_prev_0)
                x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
        else:
            x_t_ = (
                expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
                - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
            )
            if x_t is None:
                if use_predictor:
                    pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
                else:
                    pred_res = 0
                x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res

            if use_corrector:
                model_t = self.model_fn(x_t, t)
                if D1s is not None:
                    corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
                else:
                    corr_res = 0
                D1_t = (model_t - model_prev_0)
                x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
        return x_t, model_t


    def sample(self, x, timesteps, t_start=None, t_end=None, order=3, skip_type='time_uniform',
        method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
        atol=0.0078, rtol=0.05, corrector=False, callback=None, disable_pbar=False
    ):
        """Sample from the diffusion ODE along the given `timesteps`.

        Only method='multistep' is implemented. The first `order` steps are
        warmed up at increasing orders; `lower_order_final` drops the order
        near the end, and the corrector is skipped on the final step.

        Args:
            x: Initial latent at timesteps[0].
            timesteps: 1-D tensor of decreasing time labels (len = steps + 1).
            callback: Optional per-step callable receiving
                {'x', 'i', 'denoised'}.
            disable_pbar: Suppress the tqdm progress bar.
            (t_start, t_end, skip_type, denoise_to_zero, solver_type, atol,
            rtol, corrector are accepted for API compatibility but unused
            in this code path.)

        Returns:
            The latent at timesteps[-1].

        Raises:
            NotImplementedError: for any method other than 'multistep'.
        """
        # t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
        # t_T = self.noise_schedule.T if t_start is None else t_start
        device = x.device
        steps = len(timesteps) - 1
        if method == 'multistep':
            assert steps >= order
            # timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
            assert timesteps.shape[0] - 1 == steps
            # with torch.no_grad():
            for step_index in trange(steps, disable=disable_pbar):
                if step_index == 0:
                    vec_t = timesteps[0].expand((x.shape[0]))
                    model_prev_list = [self.model_fn(x, vec_t)]
                    t_prev_list = [vec_t]
                elif step_index < order:
                    init_order = step_index
                    # Init the first `order` values by lower order multistep DPM-Solver.
                    # for init_order in range(1, order):
                    vec_t = timesteps[init_order].expand(x.shape[0])
                    x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
                    if model_x is None:
                        model_x = self.model_fn(x, vec_t)
                    model_prev_list.append(model_x)
                    t_prev_list.append(vec_t)
                else:
                    extra_final_step = 0
                    if step_index == (steps - 1):
                        extra_final_step = 1
                    for step in range(step_index, step_index + 1 + extra_final_step):
                        vec_t = timesteps[step].expand(x.shape[0])
                        if lower_order_final:
                            step_order = min(order, steps + 1 - step)
                        else:
                            step_order = order
                        # print('this step order:', step_order)
                        if step == steps:
                            # print('do not run corrector at the last step')
                            use_corrector = False
                        else:
                            use_corrector = True
                        x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
                        # Shift the history windows forward by one step.
                        for i in range(order - 1):
                            t_prev_list[i] = t_prev_list[i + 1]
                            model_prev_list[i] = model_prev_list[i + 1]
                        t_prev_list[-1] = vec_t
                        # We do not need to evaluate the final model value.
                        if step < steps:
                            if model_x is None:
                                model_x = self.model_fn(x, vec_t)
                            model_prev_list[-1] = model_x
                if callback is not None:
                    callback({'x': x, 'i': step_index, 'denoised': model_prev_list[-1]})
        else:
            raise NotImplementedError()
        # if denoise_to_zero:
        #     x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
        return x
762
+
763
+
764
+ #############################################################
765
+ # other utility functions
766
+ #############################################################
767
+
768
def interpolate_fn(x, xp, yp):
    """Differentiable piecewise-linear interpolation y = f(x).

    The keypoints (xp, yp) define the function per channel; outside the
    keypoint range the outermost segment is extended (linear extrapolation).
    Implemented with sort/gather so it is usable under autograd.

    Args:
        x: tensor of shape [N, C] (C = 1 for DPM-Solver usage).
        xp: keypoint x-coordinates, shape [C, K].
        yp: keypoint y-coordinates, shape [C, K].
    Returns:
        f(x), shape [N, C].
    """
    batch, num_kp = x.shape[0], xp.shape[1]
    # Sort x together with the keypoints to locate the surrounding segment.
    combined = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((batch, 1, 1))], dim=2)
    sorted_vals, order = torch.sort(combined, dim=2)
    x_rank = torch.argmin(order, dim=2)  # position of x among the keypoints
    left_cand = x_rank - 1
    # Segment start index in the sorted array (clamped at the boundaries so
    # out-of-range x extrapolates from the outermost segment).
    lo_idx = torch.where(
        torch.eq(x_rank, 0),
        torch.tensor(1, device=x.device),
        torch.where(
            torch.eq(x_rank, num_kp), torch.tensor(num_kp - 2, device=x.device), left_cand,
        ),
    )
    # Skip over x itself (it sits between lo and hi in the sorted array).
    hi_idx = torch.where(torch.eq(lo_idx, left_cand), lo_idx + 2, lo_idx + 1)
    x_lo = torch.gather(sorted_vals, dim=2, index=lo_idx.unsqueeze(2)).squeeze(2)
    x_hi = torch.gather(sorted_vals, dim=2, index=hi_idx.unsqueeze(2)).squeeze(2)
    # Matching index into the (unsorted) keypoint arrays for the y values.
    y_lo_idx = torch.where(
        torch.eq(x_rank, 0),
        torch.tensor(0, device=x.device),
        torch.where(
            torch.eq(x_rank, num_kp), torch.tensor(num_kp - 2, device=x.device), left_cand,
        ),
    )
    y_all = yp.unsqueeze(0).expand(batch, -1, -1)
    y_lo = torch.gather(y_all, dim=2, index=y_lo_idx.unsqueeze(2)).squeeze(2)
    y_hi = torch.gather(y_all, dim=2, index=(y_lo_idx + 1).unsqueeze(2)).squeeze(2)
    return y_lo + (x - x_lo) * (y_hi - y_lo) / (x_hi - x_lo)
808
+
809
+
810
def expand_dims(v, dims):
    """Append trailing singleton dimensions to `v` until it has `dims` dims.

    Args:
        v: a PyTorch tensor of shape [N].
        dims: target total number of dimensions (int).
    Returns:
        A view of shape [N, 1, 1, ..., 1] with `dims` dimensions total.
    """
    trailing = (None,) * (dims - 1)
    return v[(Ellipsis,) + trailing]
821
+
822
+
823
class SigmaConvert:
    """Minimal noise-schedule adapter exposing VP-style marginals as a
    function of sigma, for use as the `noise_schedule` of the UniPC solver.

    Uses the identity alpha(sigma) = 1 / sqrt(sigma^2 + 1).
    """
    schedule = ""

    def marginal_log_mean_coeff(self, sigma):
        # log(alpha) = 0.5 * log(1 / (sigma^2 + 1))
        return 0.5 * torch.log(1 / ((sigma * sigma) + 1))

    def marginal_alpha(self, t):
        # alpha(t) = exp(log alpha(t))
        return torch.exp(self.marginal_log_mean_coeff(t))

    def marginal_std(self, t):
        # std(t) = sqrt(1 - alpha(t)^2)
        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """Half log-SNR: lambda_t = log(alpha_t) - log(sigma_t)."""
        log_alpha = self.marginal_log_mean_coeff(t)
        log_sigma = 0.5 * torch.log(1. - torch.exp(2. * log_alpha))
        return log_alpha - log_sigma
841
+
842
def predict_eps_sigma(model, input, sigma_in, **kwargs):
    """Adapt a denoising model to an epsilon (noise) predictor at level sigma.

    The input is rescaled by sqrt(sigma^2 + 1) before calling the model, and
    the implied noise (scaled_input - denoised) / sigma is returned.
    """
    # Broadcast sigma over all non-batch dimensions of the input.
    sigma = sigma_in.view(sigma_in.shape[:1] + (1,) * (input.ndim - 1))
    scaled = input * ((sigma ** 2 + 1.0) ** 0.5)
    denoised = model(scaled, sigma_in, **kwargs)
    return (scaled - denoised) / sigma
846
+
847
+
848
def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'):
    """Sample with the UniPC solver over a sigma schedule.

    Args:
        model: denoising model, called as model(x, sigma, **extra_args).
        noise: initial noise latent.
        sigmas: 1-D tensor of decreasing sigmas; a trailing 0 is replaced
            internally with a small positive value (UniPC cannot step to 0).
        extra_args: extra keyword args forwarded to the model.
        callback: optional per-step callback (receives {'x', 'i', 'denoised'}).
        disable: disable the progress bar.
        variant: UniPC variant, 'bh1' or 'bh2'.
    Returns:
        The final denoised latent.
    """
    # Always clone: the previous code used `sigmas[:]`, which is a *view* in
    # PyTorch, so writing 0.001 into it mutated the caller's sigma schedule.
    timesteps = sigmas.clone()
    if timesteps[-1] == 0:
        timesteps[-1] = 0.001
    ns = SigmaConvert()

    # Rescale the initial noise to the alpha/sigma parameterization.
    noise = noise / torch.sqrt(1.0 + timesteps[0] ** 2.0)
    model_type = "noise"

    model_fn = model_wrapper(
        lambda input, sigma, **kwargs: predict_eps_sigma(model, input, sigma, **kwargs),
        ns,
        model_type=model_type,
        guidance_type="uncond",
        model_kwargs=extra_args,
    )

    order = min(3, len(timesteps) - 2)
    uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=variant)
    x = uni_pc.sample(noise, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
    # Undo the residual signal scaling at the final (non-zero) timestep.
    x /= ns.marginal_alpha(timesteps[-1])
    return x
873
+
874
def sample_unipc_bh2(model, noise, sigmas, extra_args=None, callback=None, disable=False):
    """Convenience wrapper: UniPC sampling with the 'bh2' variant."""
    return sample_unipc(model, noise, sigmas, extra_args=extra_args,
                        callback=callback, disable=disable, variant='bh2')
content/flux/totoro/gligen.py ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import math
from inspect import isfunction

import torch
from torch import nn

import totoro.ops
from .ldm.modules.attention import CrossAttention

ops = totoro.ops.manual_cast
7
+
8
def exists(val):
    """Return True when `val` is not None."""
    return not (val is None)
10
+
11
+
12
def uniq(arr):
    """Order-preserving de-duplication; returns a dict keys view."""
    return dict.fromkeys(arr).keys()
14
+
15
+
16
def default(val, d):
    """Return `val` unless it is None; otherwise return the fallback `d`,
    calling it first if it is a plain function."""
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
20
+
21
+
22
+ # feedforward
23
class GEGLU(nn.Module):
    """GELU-gated linear unit: one projection to 2*dim_out, split in half,
    with one half gating the other through GELU."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = ops.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        projected = self.proj(x)
        value, gate = projected.chunk(2, dim=-1)
        return value * torch.nn.functional.gelu(gate)
31
+
32
+
33
class FeedForward(nn.Module):
    """Transformer MLP block: (optionally GEGLU-gated) up-projection,
    dropout, then projection back to dim_out (defaults to dim)."""

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        if glu:
            project_in = GEGLU(dim, inner_dim)
        else:
            project_in = nn.Sequential(
                ops.Linear(dim, inner_dim),
                nn.GELU(),
            )

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            ops.Linear(inner_dim, dim_out),
        )

    def forward(self, x):
        return self.net(x)
51
+
52
+
53
class GatedCrossAttentionDense(nn.Module):
    """GLIGEN gated cross-attention block: cross-attend visual tokens to
    object tokens, then a feed-forward, each gated by a zero-initialized
    learnable tanh(alpha) so the untrained block is an identity."""

    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()

        self.attn = CrossAttention(
            query_dim=query_dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            operations=ops)
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = ops.LayerNorm(query_dim)
        self.norm2 = ops.LayerNorm(query_dim)

        self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
        self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))

        # External knob on the tanh(alpha) magnitude; at 0 the whole block
        # is a no-op and the base model is recovered.
        self.scale = 1

    def forward(self, x, objs):
        attn_out = self.attn(self.norm1(x), objs, objs)
        x = x + self.scale * torch.tanh(self.alpha_attn) * attn_out
        ff_out = self.ff(self.norm2(x))
        x = x + self.scale * torch.tanh(self.alpha_dense) * ff_out
        return x
84
+
85
+
86
class GatedSelfAttentionDense(nn.Module):
    """GLIGEN gated self-attention block: object features are projected into
    the visual feature space, concatenated with the visual tokens, jointly
    self-attended, and only the visual slice of the output is kept. Both
    residuals are gated by zero-initialized tanh(alpha) parameters."""

    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()

        # Projection so object features can be concatenated with visual ones.
        self.linear = ops.Linear(context_dim, query_dim)

        self.attn = CrossAttention(
            query_dim=query_dim,
            context_dim=query_dim,
            heads=n_heads,
            dim_head=d_head,
            operations=ops)
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = ops.LayerNorm(query_dim)
        self.norm2 = ops.LayerNorm(query_dim)

        self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
        self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))

        # External knob on tanh(alpha); 0 makes the block an identity.
        self.scale = 1

    def forward(self, x, objs):
        n_visual = x.shape[1]
        objs = self.linear(objs)

        joint = self.attn(self.norm1(torch.cat([x, objs], dim=1)))
        x = x + self.scale * torch.tanh(self.alpha_attn) * joint[:, 0:n_visual, :]
        x = x + self.scale * torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
        return x
124
+
125
+
126
class GatedSelfAttentionDense2(nn.Module):
    """GLIGEN gated self-attention variant that spatially resizes the
    grounding-token output instead of slicing the visual tokens.

    The grounding tokens' attention output is reshaped to a square grid,
    bicubically resized to the visual grid, and added as a gated residual.
    Both token counts must therefore be perfect squares.

    Fix: this module used ``math.sqrt`` while ``math`` was never imported in
    this file, so ``forward`` raised NameError at runtime; ``import math``
    is now part of the module imports.
    """

    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()

        # Projection so object (grounding) features can be concatenated with
        # the visual tokens.
        self.linear = ops.Linear(context_dim, query_dim)

        self.attn = CrossAttention(
            query_dim=query_dim, context_dim=query_dim, dim_head=d_head, operations=ops)
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = ops.LayerNorm(query_dim)
        self.norm2 = ops.LayerNorm(query_dim)

        # Zero-initialized gates: at alpha == 0 the block is an identity, so
        # the pretrained base model is unchanged before fine-tuning.
        self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
        self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))

        # External knob on the tanh(alpha) magnitude; when set to 0 the
        # entire model behaves like the original one.
        self.scale = 1

    def forward(self, x, objs):
        B, N_visual, _ = x.shape
        B, N_ground, _ = objs.shape

        objs = self.linear(objs)

        # Sanity check: both token counts must form square grids.
        size_v = math.sqrt(N_visual)
        size_g = math.sqrt(N_ground)
        assert int(size_v) == size_v, "Visual tokens must be square rootable"
        assert int(size_g) == size_g, "Grounding tokens must be square rootable"
        size_v = int(size_v)
        size_g = int(size_g)

        # Take the grounding-token slice of the joint attention output and
        # resize it to the visual token grid to use as a residual.
        out = self.attn(self.norm1(torch.cat([x, objs], dim=1)))[
            :, N_visual:, :]
        out = out.permute(0, 2, 1).reshape(B, -1, size_g, size_g)
        out = torch.nn.functional.interpolate(
            out, (size_v, size_v), mode='bicubic')
        residual = out.reshape(B, -1, N_visual).permute(0, 2, 1)

        # Gated residual add, then the gated feed-forward.
        x = x + self.scale * torch.tanh(self.alpha_attn) * residual
        x = x + self.scale * torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))

        return x
178
+
179
+
180
class FourierEmbedder():
    """Sinusoidal feature embedding with geometrically spaced frequencies
    temperature**(k / num_freqs), k = 0..num_freqs-1."""

    def __init__(self, num_freqs=64, temperature=100):
        self.num_freqs = num_freqs
        self.temperature = temperature
        self.freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)

    @torch.no_grad()
    def __call__(self, x, cat_dim=-1):
        """Embed `x` (any shape); sin/cos features are concatenated on `cat_dim`."""
        features = []
        for band in self.freq_bands:
            features.append(torch.sin(band * x))
            features.append(torch.cos(band * x))
        return torch.cat(features, cat_dim)
195
+
196
+
197
class PositionNet(nn.Module):
    """Encodes GLIGEN grounding inputs (xyxy boxes + text embeddings) into
    per-object conditioning tokens via Fourier box features and an MLP."""

    def __init__(self, in_dim, out_dim, fourier_freqs=8):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
        self.position_dim = fourier_freqs * 2 * 4  # 2 is sin&cos, 4 is xyxy

        self.linears = nn.Sequential(
            ops.Linear(self.in_dim + self.position_dim, 512),
            nn.SiLU(),
            ops.Linear(512, 512),
            nn.SiLU(),
            ops.Linear(512, out_dim),
        )

        # Learned embeddings substituted for padded (masked-out) slots.
        self.null_positive_feature = torch.nn.Parameter(
            torch.zeros([self.in_dim]))
        self.null_position_feature = torch.nn.Parameter(
            torch.zeros([self.position_dim]))

    def forward(self, boxes, masks, positive_embeddings):
        B, N, _ = boxes.shape
        masks = masks.unsqueeze(-1)

        # Fourier-encode box coordinates: B*N*4 --> B*N*position_dim.
        xyxy_embedding = self.fourier_embedder(boxes)

        null_text = self.null_positive_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1)
        null_box = self.null_position_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1)

        # Where mask == 0 (padding), swap in the learned null embeddings.
        positive_embeddings = positive_embeddings * masks + (1 - masks) * null_text
        xyxy_embedding = xyxy_embedding * masks + (1 - masks) * null_box

        objs = self.linears(
            torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
        assert objs.shape == torch.Size([B, N, self.out_dim])
        return objs
240
+
241
+
242
class Gligen(nn.Module):
    """Container for GLIGEN gated-attention patches.

    Holds one gated attention module per transformer block plus the
    PositionNet grounding encoder, and produces patch functions that the
    diffusion model calls per transformer block (dispatched on
    extra_options["transformer_index"]).
    """

    def __init__(self, modules, position_net, key_dim):
        super().__init__()
        self.module_list = nn.ModuleList(modules)
        self.position_net = position_net
        self.key_dim = key_dim
        # Fixed capacity: conditioning is always padded/truncated to 30 objects.
        self.max_objs = 30
        self.current_device = torch.device("cpu")

    def _set_position(self, boxes, masks, positive_embeddings):
        """Encode grounding inputs once and return a per-block patch closure."""
        objs = self.position_net(boxes, masks, positive_embeddings)
        def func(x, extra_options):
            # Pick the gated module matching the current transformer block.
            key = extra_options["transformer_index"]
            module = self.module_list[key]
            return module(x, objs.to(device=x.device, dtype=x.dtype))
        return func

    def set_position(self, latent_image_shape, position_params, device):
        """Build grounding tensors from user-supplied region parameters.

        position_params: iterable of tuples; p[0] is the text embedding and
        p[1..4] appear to be (height, width, y, x) in latent units —
        TODO(review): confirm the tuple layout against the caller.
        """
        batch, c, h, w = latent_image_shape
        masks = torch.zeros([self.max_objs], device="cpu")
        boxes = []
        positive_embeddings = []
        for p in position_params:
            # Normalize the region to [0, 1] xyxy coordinates.
            x1 = (p[4]) / w
            y1 = (p[3]) / h
            x2 = (p[4] + p[2]) / w
            y2 = (p[3] + p[1]) / h
            masks[len(boxes)] = 1.0
            boxes += [torch.tensor((x1, y1, x2, y2)).unsqueeze(0)]
            positive_embeddings += [p[0]]
        # Pad up to max_objs with zero boxes/conds (mask stays 0 for padding).
        append_boxes = []
        append_conds = []
        if len(boxes) < self.max_objs:
            append_boxes = [torch.zeros(
                [self.max_objs - len(boxes), 4], device="cpu")]
            append_conds = [torch.zeros(
                [self.max_objs - len(boxes), self.key_dim], device="cpu")]

        box_out = torch.cat(
            boxes + append_boxes).unsqueeze(0).repeat(batch, 1, 1)
        masks = masks.unsqueeze(0).repeat(batch, 1)
        conds = torch.cat(positive_embeddings +
                          append_conds).unsqueeze(0).repeat(batch, 1, 1)
        return self._set_position(
            box_out.to(device),
            masks.to(device),
            conds.to(device))

    def set_empty(self, latent_image_shape, device):
        """Return a patch closure with all-null grounding (no objects)."""
        batch, c, h, w = latent_image_shape
        masks = torch.zeros([self.max_objs], device="cpu").repeat(batch, 1)
        box_out = torch.zeros([self.max_objs, 4],
                              device="cpu").repeat(batch, 1, 1)
        conds = torch.zeros([self.max_objs, self.key_dim],
                            device="cpu").repeat(batch, 1, 1)
        return self._set_position(
            box_out.to(device),
            masks.to(device),
            conds.to(device))
301
+
302
+
303
def load_gligen(sd):
    """Build a Gligen module from a GLIGEN state dict.

    Scans `sd` for fused gated self-attention weights under
    input_blocks/middle_block/output_blocks ``*.fuser.*`` keys, constructs
    one GatedSelfAttentionDense per transformer block, and loads the
    PositionNet grounding encoder weights when present.
    """
    sd_k = sd.keys()
    output_list = []
    key_dim = 768
    for a in ["input_blocks", "middle_block", "output_blocks"]:
        for b in range(20):
            # Gather this block's fuser weights, re-keyed by the suffix
            # after ".fuser." so they load directly into the module.
            k_temp = filter(lambda k: "{}.{}.".format(a, b)
                            in k and ".fuser." in k, sd_k)
            k_temp = map(lambda k: (k, k.split(".fuser.")[-1]), k_temp)

            n_sd = {}
            for k in k_temp:
                n_sd[k[1]] = sd[k[0]]
            if len(n_sd) > 0:
                # Infer dimensions from the object-projection weight.
                query_dim = n_sd["linear.weight"].shape[0]
                key_dim = n_sd["linear.weight"].shape[1]

                if key_dim == 768: # SD1.x
                    n_heads = 8
                    d_head = query_dim // n_heads
                else:
                    # Other models: fixed head dim of 64, derive head count.
                    d_head = 64
                    n_heads = query_dim // d_head

                gated = GatedSelfAttentionDense(
                    query_dim, key_dim, n_heads, d_head)
                gated.load_state_dict(n_sd, strict=False)
                output_list.append(gated)

    if "position_net.null_positive_feature" in sd_k:
        in_dim = sd["position_net.null_positive_feature"].shape[0]
        out_dim = sd["position_net.linears.4.weight"].shape[0]

        # Throwaway container so load_state_dict resolves "position_net.*" keys.
        class WeightsLoader(torch.nn.Module):
            pass
        w = WeightsLoader()
        w.position_net = PositionNet(in_dim, out_dim)
        w.load_state_dict(sd, strict=False)

    # NOTE(review): `w` is only bound inside the `if` above; a state dict
    # without position_net keys would raise NameError here — confirm all
    # supported checkpoints include them.
    gligen = Gligen(output_list, w.position_net, key_dim)
    return gligen
content/flux/totoro/k_diffusion/deis.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #Taken from: https://github.com/zju-pi/diff-sampler/blob/main/gits-main/solver_utils.py
2
+ #under Apache 2 license
3
+ import torch
4
+ import numpy as np
5
+
6
+ # A pytorch reimplementation of DEIS (https://github.com/qsh-zh/deis).
7
+ #############################
8
+ ### Utils for DEIS solver ###
9
+ #############################
10
+ #----------------------------------------------------------------------------
11
+ # Transfer from the input time (sigma) used in EDM to that (t) used in DEIS.
12
+
13
def edm2t(edm_steps, epsilon_s=1e-3, sigma_min=0.002, sigma_max=80):
    """Convert EDM sigma steps to VP-SDE time steps.

    Fits a linear beta schedule so that sigma(epsilon_s) = sigma_min and
    sigma(1) = sigma_max, then inverts sigma(t) at each given step.

    Returns:
        (t_steps, vp_beta_min, vp_beta_max) where vp_beta_max = beta_d + beta_min.
    """
    # log(sigma^2 + 1) as a CPU tensor — the quantity the VP schedule is fit to.
    log1p_sq = lambda s: np.log(torch.tensor(s).cpu() ** 2 + 1)
    vp_beta_d = 2 * (log1p_sq(sigma_min) / epsilon_s - log1p_sq(sigma_max)) / (epsilon_s - 1)
    vp_beta_min = log1p_sq(sigma_max) - 0.5 * vp_beta_d
    beta_d = vp_beta_d.clone().detach().cpu()
    beta_min = vp_beta_min.clone().detach().cpu()
    sig = edm_steps.clone().detach().cpu()
    # Closed-form inverse of sigma(t) for the linear VP schedule.
    t_steps = ((beta_min ** 2 + 2 * beta_d * (sig ** 2 + 1).log()).sqrt() - beta_min) / beta_d
    return t_steps, vp_beta_min, vp_beta_d + vp_beta_min
20
+
21
+ #----------------------------------------------------------------------------
22
+
23
def cal_poly(prev_t, j, taus):
    """Evaluate the j-th Lagrange basis polynomial with nodes `prev_t` at `taus`."""
    result = 1
    node_count = prev_t.shape[0]
    for m in range(node_count):
        if m == j:
            continue
        result = result * ((taus - prev_t[m]) / (prev_t[j] - prev_t[m]))
    return result
30
+
31
+ #----------------------------------------------------------------------------
32
+ # Transfer from t to alpha_t.
33
+
34
def t2alpha_fn(beta_0, beta_1, t):
    """alpha(t) = exp(-(beta_1 - beta_0) * t^2 / 2 - beta_0 * t) for a linear VP beta schedule."""
    exponent = -0.5 * t ** 2 * (beta_1 - beta_0) - t * beta_0
    return torch.exp(exponent)
36
+
37
+ #----------------------------------------------------------------------------
38
+
39
def cal_intergrand(beta_0, beta_1, taus):
    """Evaluate the DEIS integrand -0.5 * d(log alpha)/dtau / sqrt(alpha * (1 - alpha)) at `taus`.

    The derivative of log alpha(t) is obtained through autograd rather than
    a hand-derived formula, so the inputs must first be moved out of
    inference mode and grad re-enabled.
    """
    with torch.inference_mode(mode=False):
        # Clone so the operands are ordinary tensors autograd can track
        # (callers may be running under inference mode).
        taus = taus.clone()
        beta_0 = beta_0.clone()
        beta_1 = beta_1.clone()
        with torch.enable_grad():
            taus.requires_grad_(True)
            alpha = t2alpha_fn(beta_0, beta_1, taus)
            log_alpha = alpha.log()
            # Sum-then-backward yields the elementwise derivative, since each
            # alpha[i] depends only on taus[i].
            log_alpha.sum().backward()
            d_log_alpha_dtau = taus.grad
    integrand = -0.5 * d_log_alpha_dtau / torch.sqrt(alpha * (1 - alpha))
    return integrand
52
+
53
+ #----------------------------------------------------------------------------
54
+
55
def get_deis_coeff_list(t_steps, max_order, N=10000, deis_mode='tab'):
    """
    Get the coefficient list for DEIS sampling.

    Args:
        t_steps: A pytorch tensor. The time steps for sampling.
        max_order: A `int`. Maximum order of the solver. 1 <= max_order <= 4
        N: A `int`. Use how many points to perform the numerical integration when deis_mode=='tab'.
        deis_mode: A `str`. Select between 'tab' and 'rhoab'. Type of DEIS.
    Returns:
        A list with one entry per sampling step: a list of per-history-point
        coefficients for that step (empty for first-order steps).
    """
    if deis_mode == 'tab':
        # Numerical ("tableau") mode: integrate the VP integrand against each
        # Lagrange basis polynomial with an N-point Riemann sum per step.
        t_steps, beta_0, beta_1 = edm2t(t_steps)
        C = []
        for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
            order = min(i+1, max_order)
            if order == 1:
                C.append([])
            else:
                taus = torch.linspace(t_cur, t_next, N)   # split the interval for integral approximation
                dtau = (t_next - t_cur) / N
                # The `order` most recent time points, newest first.
                prev_t = t_steps[[i - k for k in range(order)]]
                coeff_temp = []
                integrand = cal_intergrand(beta_0, beta_1, taus)
                for j in range(order):
                    poly = cal_poly(prev_t, j, taus)
                    coeff_temp.append(torch.sum(integrand * poly) * dtau)
                C.append(coeff_temp)

    elif deis_mode == 'rhoab':
        # Analytical solution, second order
        def get_def_intergral_2(a, b, start, end, c):
            coeff = (end**3 - start**3) / 3 - (end**2 - start**2) * (a + b) / 2 + (end - start) * a * b
            return coeff / ((c - a) * (c - b))

        # Analytical solution, third order
        def get_def_intergral_3(a, b, c, start, end, d):
            coeff = (end**4 - start**4) / 4 - (end**3 - start**3) * (a + b + c) / 3 \
                + (end**2 - start**2) * (a*b + a*c + b*c) / 2 - (end - start) * a * b * c
            return coeff / ((d - a) * (d - b) * (d - c))

        C = []
        for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
            order = min(i, max_order)
            if order == 0:
                C.append([])
            else:
                # order+1 most recent time points, newest first.
                prev_t = t_steps[[i - k for k in range(order+1)]]
                if order == 1:
                    coeff_cur = ((t_next - prev_t[1])**2 - (t_cur - prev_t[1])**2) / (2 * (t_cur - prev_t[1]))
                    coeff_prev1 = (t_next - t_cur)**2 / (2 * (prev_t[1] - t_cur))
                    coeff_temp = [coeff_cur, coeff_prev1]
                elif order == 2:
                    coeff_cur = get_def_intergral_2(prev_t[1], prev_t[2], t_cur, t_next, t_cur)
                    coeff_prev1 = get_def_intergral_2(t_cur, prev_t[2], t_cur, t_next, prev_t[1])
                    coeff_prev2 = get_def_intergral_2(t_cur, prev_t[1], t_cur, t_next, prev_t[2])
                    coeff_temp = [coeff_cur, coeff_prev1, coeff_prev2]
                elif order == 3:
                    coeff_cur = get_def_intergral_3(prev_t[1], prev_t[2], prev_t[3], t_cur, t_next, t_cur)
                    coeff_prev1 = get_def_intergral_3(t_cur, prev_t[2], prev_t[3], t_cur, t_next, prev_t[1])
                    coeff_prev2 = get_def_intergral_3(t_cur, prev_t[1], prev_t[3], t_cur, t_next, prev_t[2])
                    coeff_prev3 = get_def_intergral_3(t_cur, prev_t[1], prev_t[2], t_cur, t_next, prev_t[3])
                    coeff_temp = [coeff_cur, coeff_prev1, coeff_prev2, coeff_prev3]
                C.append(coeff_temp)
    return C
121
+
content/flux/totoro/k_diffusion/sampling.py ADDED
@@ -0,0 +1,1049 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ from scipy import integrate
4
+ import torch
5
+ from torch import nn
6
+ import torchsde
7
+ from tqdm.auto import trange, tqdm
8
+
9
+ from . import utils
10
+ from . import deis
11
+ import totoro.model_patcher
12
+
13
def append_zero(x):
    """Append a single zero element to the end of a 1-D tensor."""
    zero = x.new_zeros([1])
    return torch.cat((x, zero))
15
+
16
+
17
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
    """Constructs the noise schedule of Karras et al. (2022), with a trailing zero."""
    steps = torch.linspace(0, 1, n, device=device)
    lo = sigma_min ** (1 / rho)
    hi = sigma_max ** (1 / rho)
    # Interpolate in sigma^(1/rho) space, then raise back to the rho power.
    sigmas = (hi + steps * (lo - hi)) ** rho
    sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
    return sigmas.to(device)
24
+
25
+
26
def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
    """Constructs an exponential noise schedule, with a trailing zero."""
    log_sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device)
    sigmas = log_sigmas.exp()
    return torch.cat([sigmas, sigmas.new_zeros([1])])
30
+
31
+
32
def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
    """Constructs a polynomial-in-log-sigma noise schedule, with a trailing zero."""
    ramp = torch.linspace(1, 0, n, device=device) ** rho
    log_span = math.log(sigma_max) - math.log(sigma_min)
    sigmas = torch.exp(ramp * log_span + math.log(sigma_min))
    return torch.cat([sigmas, sigmas.new_zeros([1])])
37
+
38
+
39
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
    """Constructs a continuous VP noise schedule, with a trailing zero."""
    t = torch.linspace(1, eps_s, n, device=device)
    variance = torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1
    sigmas = torch.sqrt(variance)
    return torch.cat([sigmas, sigmas.new_zeros([1])])
44
+
45
+
46
def to_d(x, sigma, denoised):
    """Convert a denoiser output to a Karras ODE derivative: (x - denoised) / sigma."""
    residual = x - denoised
    return residual / utils.append_dims(sigma, x.ndim)
49
+
50
+
51
def get_ancestral_step(sigma_from, sigma_to, eta=1.):
    """Split an ancestral sampling step into a deterministic and a noisy part.

    Returns (sigma_down, sigma_up): the level to step down to deterministically
    and the amount of fresh noise to add. With eta == 0 the step is fully
    deterministic.
    """
    if not eta:
        return sigma_to, 0.
    variance_frac = (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2
    sigma_up = min(sigma_to, eta * (sigma_to ** 2 * variance_frac) ** 0.5)
    sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up
59
+
60
+
61
def default_noise_sampler(x):
    """Return a noise sampler drawing unit Gaussian noise shaped like `x`,
    ignoring the (sigma, sigma_next) arguments."""
    def sampler(sigma, sigma_next):
        return torch.randn_like(x)
    return sampler
63
+
64
+
65
class BatchedBrownianTree:
    """A wrapper around torchsde.BrownianTree that enables batches of entropy."""

    def __init__(self, x, t0, t1, seed=None, **kwargs):
        # Keep the underlying tree(s) on CPU unless cpu=False is passed.
        self.cpu_tree = True
        if "cpu" in kwargs:
            self.cpu_tree = kwargs.pop("cpu")
        # Normalize the interval; remember the sign so increments over a
        # reversed interval come out negated.
        t0, t1, self.sign = self.sort(t0, t1)
        w0 = kwargs.get('w0', torch.zeros_like(x))
        if seed is None:
            seed = torch.randint(0, 2 ** 63 - 1, []).item()
        self.batched = True
        try:
            # A sequence of seeds means one tree per batch item;
            # a bare int raises TypeError on len() and selects a single
            # shared tree instead.
            assert len(seed) == x.shape[0]
            w0 = w0[0]
        except TypeError:
            seed = [seed]
            self.batched = False
        if self.cpu_tree:
            self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed]
        else:
            self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]

    @staticmethod
    def sort(a, b):
        # Returns (min, max, sign) where sign is -1 if the pair was swapped.
        return (a, b, 1) if a < b else (b, a, -1)

    def __call__(self, t0, t1):
        t0, t1, sign = self.sort(t0, t1)
        if self.cpu_tree:
            # Sample on CPU, then move back to the query's dtype/device.
            w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign)
        else:
            w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)

        return w if self.batched else w[0]
100
+
101
+
102
class BrownianTreeNoiseSampler:
    """A noise sampler backed by a torchsde.BrownianTree.

    Args:
        x (Tensor): tensor providing the shape, device and dtype of the
            generated noise samples.
        sigma_min (float): low end of the valid sigma interval.
        sigma_max (float): high end of the valid sigma interval.
        seed (int or List[int]): random seed; a list of seeds gives each
            batch item its own BrownianTree.
        transform (callable): maps sigma to the sampler's internal timestep.
    """

    def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False):
        self.transform = transform
        t0 = self.transform(torch.as_tensor(sigma_min))
        t1 = self.transform(torch.as_tensor(sigma_max))
        self.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu)

    def __call__(self, sigma, sigma_next):
        t0 = self.transform(torch.as_tensor(sigma))
        t1 = self.transform(torch.as_tensor(sigma_next))
        # Normalize the Brownian increment to unit variance.
        return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
125
+
126
+
127
@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).

    Args:
        model: denoiser called as model(x, sigma, **extra_args).
        x: initial latent at sigmas[0].
        sigmas: 1-D decreasing sigma schedule; len(sigmas) - 1 steps are taken.
        s_churn/s_tmin/s_tmax/s_noise: stochastic "churn" parameters; with
            s_churn == 0 the sampler is fully deterministic.
        callback: optional per-step callable receiving a progress dict.
    Returns:
        The final latent.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])  # broadcasts the scalar sigma over the batch
    for i in trange(len(sigmas) - 1, disable=disable):
        if s_churn > 0:
            # Temporarily raise sigma by a factor (1 + gamma), but only when
            # the current sigma lies inside [s_tmin, s_tmax].
            gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
            sigma_hat = sigmas[i] * (gamma + 1)
        else:
            gamma = 0
            sigma_hat = sigmas[i]

        if gamma > 0:
            # Add fresh noise to bring x up to the churned sigma_hat level.
            eps = torch.randn_like(x) * s_noise
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigmas[i + 1] - sigma_hat
        # Euler method
        x = x + d * dt
    return x
151
+
152
+
153
+ @torch.no_grad()
154
+ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
155
+ """Ancestral sampling with Euler method steps."""
156
+ extra_args = {} if extra_args is None else extra_args
157
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
158
+ s_in = x.new_ones([x.shape[0]])
159
+ for i in trange(len(sigmas) - 1, disable=disable):
160
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
161
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
162
+ if callback is not None:
163
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
164
+ d = to_d(x, sigmas[i], denoised)
165
+ # Euler method
166
+ dt = sigma_down - sigmas[i]
167
+ x = x + d * dt
168
+ if sigmas[i + 1] > 0:
169
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
170
+ return x
171
+
172
+
173
+ @torch.no_grad()
174
+ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
175
+ """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
176
+ extra_args = {} if extra_args is None else extra_args
177
+ s_in = x.new_ones([x.shape[0]])
178
+ for i in trange(len(sigmas) - 1, disable=disable):
179
+ if s_churn > 0:
180
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
181
+ sigma_hat = sigmas[i] * (gamma + 1)
182
+ else:
183
+ gamma = 0
184
+ sigma_hat = sigmas[i]
185
+
186
+ sigma_hat = sigmas[i] * (gamma + 1)
187
+ if gamma > 0:
188
+ eps = torch.randn_like(x) * s_noise
189
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
190
+ denoised = model(x, sigma_hat * s_in, **extra_args)
191
+ d = to_d(x, sigma_hat, denoised)
192
+ if callback is not None:
193
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
194
+ dt = sigmas[i + 1] - sigma_hat
195
+ if sigmas[i + 1] == 0:
196
+ # Euler method
197
+ x = x + d * dt
198
+ else:
199
+ # Heun's method
200
+ x_2 = x + d * dt
201
+ denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
202
+ d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
203
+ d_prime = (d + d_2) / 2
204
+ x = x + d_prime * dt
205
+ return x
206
+
207
+
208
+ @torch.no_grad()
209
+ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
210
+ """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
211
+ extra_args = {} if extra_args is None else extra_args
212
+ s_in = x.new_ones([x.shape[0]])
213
+ for i in trange(len(sigmas) - 1, disable=disable):
214
+ if s_churn > 0:
215
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
216
+ sigma_hat = sigmas[i] * (gamma + 1)
217
+ else:
218
+ gamma = 0
219
+ sigma_hat = sigmas[i]
220
+
221
+ if gamma > 0:
222
+ eps = torch.randn_like(x) * s_noise
223
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
224
+ denoised = model(x, sigma_hat * s_in, **extra_args)
225
+ d = to_d(x, sigma_hat, denoised)
226
+ if callback is not None:
227
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
228
+ if sigmas[i + 1] == 0:
229
+ # Euler method
230
+ dt = sigmas[i + 1] - sigma_hat
231
+ x = x + d * dt
232
+ else:
233
+ # DPM-Solver-2
234
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
235
+ dt_1 = sigma_mid - sigma_hat
236
+ dt_2 = sigmas[i + 1] - sigma_hat
237
+ x_2 = x + d * dt_1
238
+ denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
239
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
240
+ x = x + d_2 * dt_2
241
+ return x
242
+
243
+
244
+ @torch.no_grad()
245
+ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
246
+ """Ancestral sampling with DPM-Solver second-order steps."""
247
+ extra_args = {} if extra_args is None else extra_args
248
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
249
+ s_in = x.new_ones([x.shape[0]])
250
+ for i in trange(len(sigmas) - 1, disable=disable):
251
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
252
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
253
+ if callback is not None:
254
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
255
+ d = to_d(x, sigmas[i], denoised)
256
+ if sigma_down == 0:
257
+ # Euler method
258
+ dt = sigma_down - sigmas[i]
259
+ x = x + d * dt
260
+ else:
261
+ # DPM-Solver-2
262
+ sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
263
+ dt_1 = sigma_mid - sigmas[i]
264
+ dt_2 = sigma_down - sigmas[i]
265
+ x_2 = x + d * dt_1
266
+ denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
267
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
268
+ x = x + d_2 * dt_2
269
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
270
+ return x
271
+
272
+
273
+ def linear_multistep_coeff(order, t, i, j):
274
+ if order - 1 > i:
275
+ raise ValueError(f'Order {order} too high for step {i}')
276
+ def fn(tau):
277
+ prod = 1.
278
+ for k in range(order):
279
+ if j == k:
280
+ continue
281
+ prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
282
+ return prod
283
+ return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0]
284
+
285
+
286
+ @torch.no_grad()
287
+ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
288
+ extra_args = {} if extra_args is None else extra_args
289
+ s_in = x.new_ones([x.shape[0]])
290
+ sigmas_cpu = sigmas.detach().cpu().numpy()
291
+ ds = []
292
+ for i in trange(len(sigmas) - 1, disable=disable):
293
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
294
+ d = to_d(x, sigmas[i], denoised)
295
+ ds.append(d)
296
+ if len(ds) > order:
297
+ ds.pop(0)
298
+ if callback is not None:
299
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
300
+ cur_order = min(i + 1, order)
301
+ coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
302
+ x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
303
+ return x
304
+
305
+
306
+ class PIDStepSizeController:
307
+ """A PID controller for ODE adaptive step size control."""
308
+ def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
309
+ self.h = h
310
+ self.b1 = (pcoeff + icoeff + dcoeff) / order
311
+ self.b2 = -(pcoeff + 2 * dcoeff) / order
312
+ self.b3 = dcoeff / order
313
+ self.accept_safety = accept_safety
314
+ self.eps = eps
315
+ self.errs = []
316
+
317
+ def limiter(self, x):
318
+ return 1 + math.atan(x - 1)
319
+
320
+ def propose_step(self, error):
321
+ inv_error = 1 / (float(error) + self.eps)
322
+ if not self.errs:
323
+ self.errs = [inv_error, inv_error, inv_error]
324
+ self.errs[0] = inv_error
325
+ factor = self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3
326
+ factor = self.limiter(factor)
327
+ accept = factor >= self.accept_safety
328
+ if accept:
329
+ self.errs[2] = self.errs[1]
330
+ self.errs[1] = self.errs[0]
331
+ self.h *= factor
332
+ return accept
333
+
334
+
335
+ class DPMSolver(nn.Module):
336
+ """DPM-Solver. See https://arxiv.org/abs/2206.00927."""
337
+
338
+ def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None):
339
+ super().__init__()
340
+ self.model = model
341
+ self.extra_args = {} if extra_args is None else extra_args
342
+ self.eps_callback = eps_callback
343
+ self.info_callback = info_callback
344
+
345
+ def t(self, sigma):
346
+ return -sigma.log()
347
+
348
+ def sigma(self, t):
349
+ return t.neg().exp()
350
+
351
+ def eps(self, eps_cache, key, x, t, *args, **kwargs):
352
+ if key in eps_cache:
353
+ return eps_cache[key], eps_cache
354
+ sigma = self.sigma(t) * x.new_ones([x.shape[0]])
355
+ eps = (x - self.model(x, sigma, *args, **self.extra_args, **kwargs)) / self.sigma(t)
356
+ if self.eps_callback is not None:
357
+ self.eps_callback()
358
+ return eps, {key: eps, **eps_cache}
359
+
360
+ def dpm_solver_1_step(self, x, t, t_next, eps_cache=None):
361
+ eps_cache = {} if eps_cache is None else eps_cache
362
+ h = t_next - t
363
+ eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
364
+ x_1 = x - self.sigma(t_next) * h.expm1() * eps
365
+ return x_1, eps_cache
366
+
367
+ def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None):
368
+ eps_cache = {} if eps_cache is None else eps_cache
369
+ h = t_next - t
370
+ eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
371
+ s1 = t + r1 * h
372
+ u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
373
+ eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
374
+ x_2 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps)
375
+ return x_2, eps_cache
376
+
377
+ def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None):
378
+ eps_cache = {} if eps_cache is None else eps_cache
379
+ h = t_next - t
380
+ eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
381
+ s1 = t + r1 * h
382
+ s2 = t + r2 * h
383
+ u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
384
+ eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
385
+ u2 = x - self.sigma(s2) * (r2 * h).expm1() * eps - self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps)
386
+ eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2)
387
+ x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps)
388
+ return x_3, eps_cache
389
+
390
+ def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
391
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
392
+ if not t_end > t_start and eta:
393
+ raise ValueError('eta must be 0 for reverse sampling')
394
+
395
+ m = math.floor(nfe / 3) + 1
396
+ ts = torch.linspace(t_start, t_end, m + 1, device=x.device)
397
+
398
+ if nfe % 3 == 0:
399
+ orders = [3] * (m - 2) + [2, 1]
400
+ else:
401
+ orders = [3] * (m - 1) + [nfe % 3]
402
+
403
+ for i in range(len(orders)):
404
+ eps_cache = {}
405
+ t, t_next = ts[i], ts[i + 1]
406
+ if eta:
407
+ sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
408
+ t_next_ = torch.minimum(t_end, self.t(sd))
409
+ su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5
410
+ else:
411
+ t_next_, su = t_next, 0.
412
+
413
+ eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
414
+ denoised = x - self.sigma(t) * eps
415
+ if self.info_callback is not None:
416
+ self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised})
417
+
418
+ if orders[i] == 1:
419
+ x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache)
420
+ elif orders[i] == 2:
421
+ x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache)
422
+ else:
423
+ x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache)
424
+
425
+ x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next))
426
+
427
+ return x
428
+
429
+ def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
430
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
431
+ if order not in {2, 3}:
432
+ raise ValueError('order should be 2 or 3')
433
+ forward = t_end > t_start
434
+ if not forward and eta:
435
+ raise ValueError('eta must be 0 for reverse sampling')
436
+ h_init = abs(h_init) * (1 if forward else -1)
437
+ atol = torch.tensor(atol)
438
+ rtol = torch.tensor(rtol)
439
+ s = t_start
440
+ x_prev = x
441
+ accept = True
442
+ pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety)
443
+ info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0}
444
+
445
+ while s < t_end - 1e-5 if forward else s > t_end + 1e-5:
446
+ eps_cache = {}
447
+ t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h)
448
+ if eta:
449
+ sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta)
450
+ t_ = torch.minimum(t_end, self.t(sd))
451
+ su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5
452
+ else:
453
+ t_, su = t, 0.
454
+
455
+ eps, eps_cache = self.eps(eps_cache, 'eps', x, s)
456
+ denoised = x - self.sigma(s) * eps
457
+
458
+ if order == 2:
459
+ x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache)
460
+ x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache)
461
+ else:
462
+ x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache)
463
+ x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache)
464
+ delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs()))
465
+ error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5
466
+ accept = pid.propose_step(error)
467
+ if accept:
468
+ x_prev = x_low
469
+ x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t))
470
+ s = t
471
+ info['n_accept'] += 1
472
+ else:
473
+ info['n_reject'] += 1
474
+ info['nfe'] += order
475
+ info['steps'] += 1
476
+
477
+ if self.info_callback is not None:
478
+ self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info})
479
+
480
+ return x, info
481
+
482
+
483
+ @torch.no_grad()
484
+ def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None):
485
+ """DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927."""
486
+ if sigma_min <= 0 or sigma_max <= 0:
487
+ raise ValueError('sigma_min and sigma_max must not be 0')
488
+ with tqdm(total=n, disable=disable) as pbar:
489
+ dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
490
+ if callback is not None:
491
+ dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
492
+ return dpm_solver.dpm_solver_fast(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), n, eta, s_noise, noise_sampler)
493
+
494
+
495
+ @torch.no_grad()
496
+ def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False):
497
+ """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927."""
498
+ if sigma_min <= 0 or sigma_max <= 0:
499
+ raise ValueError('sigma_min and sigma_max must not be 0')
500
+ with tqdm(disable=disable) as pbar:
501
+ dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
502
+ if callback is not None:
503
+ dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
504
+ x, info = dpm_solver.dpm_solver_adaptive(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise, noise_sampler)
505
+ if return_info:
506
+ return x, info
507
+ return x
508
+
509
+
510
+ @torch.no_grad()
511
+ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
512
+ """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
513
+ extra_args = {} if extra_args is None else extra_args
514
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
515
+ s_in = x.new_ones([x.shape[0]])
516
+ sigma_fn = lambda t: t.neg().exp()
517
+ t_fn = lambda sigma: sigma.log().neg()
518
+
519
+ for i in trange(len(sigmas) - 1, disable=disable):
520
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
521
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
522
+ if callback is not None:
523
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
524
+ if sigma_down == 0:
525
+ # Euler method
526
+ d = to_d(x, sigmas[i], denoised)
527
+ dt = sigma_down - sigmas[i]
528
+ x = x + d * dt
529
+ else:
530
+ # DPM-Solver++(2S)
531
+ t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
532
+ r = 1 / 2
533
+ h = t_next - t
534
+ s = t + r * h
535
+ x_2 = (sigma_fn(s) / sigma_fn(t)) * x - (-h * r).expm1() * denoised
536
+ denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
537
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2
538
+ # Noise addition
539
+ if sigmas[i + 1] > 0:
540
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
541
+ return x
542
+
543
+
544
+ @torch.no_grad()
545
+ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
546
+ """DPM-Solver++ (stochastic)."""
547
+ if len(sigmas) <= 1:
548
+ return x
549
+
550
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
551
+ seed = extra_args.get("seed", None)
552
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
553
+ extra_args = {} if extra_args is None else extra_args
554
+ s_in = x.new_ones([x.shape[0]])
555
+ sigma_fn = lambda t: t.neg().exp()
556
+ t_fn = lambda sigma: sigma.log().neg()
557
+
558
+ for i in trange(len(sigmas) - 1, disable=disable):
559
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
560
+ if callback is not None:
561
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
562
+ if sigmas[i + 1] == 0:
563
+ # Euler method
564
+ d = to_d(x, sigmas[i], denoised)
565
+ dt = sigmas[i + 1] - sigmas[i]
566
+ x = x + d * dt
567
+ else:
568
+ # DPM-Solver++
569
+ t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
570
+ h = t_next - t
571
+ s = t + h * r
572
+ fac = 1 / (2 * r)
573
+
574
+ # Step 1
575
+ sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
576
+ s_ = t_fn(sd)
577
+ x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
578
+ x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
579
+ denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
580
+
581
+ # Step 2
582
+ sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
583
+ t_next_ = t_fn(sd)
584
+ denoised_d = (1 - fac) * denoised + fac * denoised_2
585
+ x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
586
+ x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
587
+ return x
588
+
589
+
590
+ @torch.no_grad()
591
+ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
592
+ """DPM-Solver++(2M)."""
593
+ extra_args = {} if extra_args is None else extra_args
594
+ s_in = x.new_ones([x.shape[0]])
595
+ sigma_fn = lambda t: t.neg().exp()
596
+ t_fn = lambda sigma: sigma.log().neg()
597
+ old_denoised = None
598
+
599
+ for i in trange(len(sigmas) - 1, disable=disable):
600
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
601
+ if callback is not None:
602
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
603
+ t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
604
+ h = t_next - t
605
+ if old_denoised is None or sigmas[i + 1] == 0:
606
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
607
+ else:
608
+ h_last = t - t_fn(sigmas[i - 1])
609
+ r = h_last / h
610
+ denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
611
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
612
+ old_denoised = denoised
613
+ return x
614
+
615
+ @torch.no_grad()
616
+ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
617
+ """DPM-Solver++(2M) SDE."""
618
+ if len(sigmas) <= 1:
619
+ return x
620
+
621
+ if solver_type not in {'heun', 'midpoint'}:
622
+ raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
623
+
624
+ seed = extra_args.get("seed", None)
625
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
626
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
627
+ extra_args = {} if extra_args is None else extra_args
628
+ s_in = x.new_ones([x.shape[0]])
629
+
630
+ old_denoised = None
631
+ h_last = None
632
+ h = None
633
+
634
+ for i in trange(len(sigmas) - 1, disable=disable):
635
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
636
+ if callback is not None:
637
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
638
+ if sigmas[i + 1] == 0:
639
+ # Denoising step
640
+ x = denoised
641
+ else:
642
+ # DPM-Solver++(2M) SDE
643
+ t, s = -sigmas[i].log(), -sigmas[i + 1].log()
644
+ h = s - t
645
+ eta_h = eta * h
646
+
647
+ x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
648
+
649
+ if old_denoised is not None:
650
+ r = h_last / h
651
+ if solver_type == 'heun':
652
+ x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
653
+ elif solver_type == 'midpoint':
654
+ x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
655
+
656
+ if eta:
657
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
658
+
659
+ old_denoised = denoised
660
+ h_last = h
661
+ return x
662
+
663
+ @torch.no_grad()
664
+ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
665
+ """DPM-Solver++(3M) SDE."""
666
+
667
+ if len(sigmas) <= 1:
668
+ return x
669
+
670
+ seed = extra_args.get("seed", None)
671
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
672
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
673
+ extra_args = {} if extra_args is None else extra_args
674
+ s_in = x.new_ones([x.shape[0]])
675
+
676
+ denoised_1, denoised_2 = None, None
677
+ h, h_1, h_2 = None, None, None
678
+
679
+ for i in trange(len(sigmas) - 1, disable=disable):
680
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
681
+ if callback is not None:
682
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
683
+ if sigmas[i + 1] == 0:
684
+ # Denoising step
685
+ x = denoised
686
+ else:
687
+ t, s = -sigmas[i].log(), -sigmas[i + 1].log()
688
+ h = s - t
689
+ h_eta = h * (eta + 1)
690
+
691
+ x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised
692
+
693
+ if h_2 is not None:
694
+ r0 = h_1 / h
695
+ r1 = h_2 / h
696
+ d1_0 = (denoised - denoised_1) / r0
697
+ d1_1 = (denoised_1 - denoised_2) / r1
698
+ d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1)
699
+ d2 = (d1_0 - d1_1) / (r0 + r1)
700
+ phi_2 = h_eta.neg().expm1() / h_eta + 1
701
+ phi_3 = phi_2 / h_eta - 0.5
702
+ x = x + phi_2 * d1 - phi_3 * d2
703
+ elif h_1 is not None:
704
+ r = h_1 / h
705
+ d = (denoised - denoised_1) / r
706
+ phi_2 = h_eta.neg().expm1() / h_eta + 1
707
+ x = x + phi_2 * d
708
+
709
+ if eta:
710
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
711
+
712
+ denoised_1, denoised_2 = denoised, denoised_1
713
+ h_1, h_2 = h, h_1
714
+ return x
715
+
716
+ @torch.no_grad()
717
+ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
718
+ if len(sigmas) <= 1:
719
+ return x
720
+
721
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
722
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
723
+ return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
724
+
725
+ @torch.no_grad()
726
+ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
727
+ if len(sigmas) <= 1:
728
+ return x
729
+
730
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
731
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
732
+ return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
733
+
734
+ @torch.no_grad()
735
+ def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
736
+ if len(sigmas) <= 1:
737
+ return x
738
+
739
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
740
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
741
+ return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
742
+
743
+
744
+ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
745
+ alpha_cumprod = 1 / ((sigma * sigma) + 1)
746
+ alpha_cumprod_prev = 1 / ((sigma_prev * sigma_prev) + 1)
747
+ alpha = (alpha_cumprod / alpha_cumprod_prev)
748
+
749
+ mu = (1.0 / alpha).sqrt() * (x - (1 - alpha) * noise / (1 - alpha_cumprod).sqrt())
750
+ if sigma_prev > 0:
751
+ mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev)
752
+ return mu
753
+
754
+ def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
755
+ extra_args = {} if extra_args is None else extra_args
756
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
757
+ s_in = x.new_ones([x.shape[0]])
758
+
759
+ for i in trange(len(sigmas) - 1, disable=disable):
760
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
761
+ if callback is not None:
762
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
763
+ x = step_function(x / torch.sqrt(1.0 + sigmas[i] ** 2.0), sigmas[i], sigmas[i + 1], (x - denoised) / sigmas[i], noise_sampler)
764
+ if sigmas[i + 1] != 0:
765
+ x *= torch.sqrt(1.0 + sigmas[i + 1] ** 2.0)
766
+ return x
767
+
768
+
769
+ @torch.no_grad()
770
+ def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
771
+ return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step)
772
+
773
+ @torch.no_grad()
774
+ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
775
+ extra_args = {} if extra_args is None else extra_args
776
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
777
+ s_in = x.new_ones([x.shape[0]])
778
+ for i in trange(len(sigmas) - 1, disable=disable):
779
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
780
+ if callback is not None:
781
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
782
+
783
+ x = denoised
784
+ if sigmas[i + 1] > 0:
785
+ x = model.inner_model.inner_model.model_sampling.noise_scaling(sigmas[i + 1], noise_sampler(sigmas[i], sigmas[i + 1]), x)
786
+ return x
787
+
788
+
789
+
790
+ @torch.no_grad()
791
+ def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
792
+ extra_args = {} if extra_args is None else extra_args
793
+ s_in = x.new_ones([x.shape[0]])
794
+ s_end = sigmas[-1]
795
+ for i in trange(len(sigmas) - 1, disable=disable):
796
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
797
+ eps = torch.randn_like(x) * s_noise
798
+ sigma_hat = sigmas[i] * (gamma + 1)
799
+ if gamma > 0:
800
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
801
+ denoised = model(x, sigma_hat * s_in, **extra_args)
802
+ d = to_d(x, sigma_hat, denoised)
803
+ if callback is not None:
804
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
805
+ dt = sigmas[i + 1] - sigma_hat
806
+ if sigmas[i + 1] == s_end:
807
+ # Euler method
808
+ x = x + d * dt
809
+ elif sigmas[i + 2] == s_end:
810
+
811
+ # Heun's method
812
+ x_2 = x + d * dt
813
+ denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
814
+ d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
815
+
816
+ w = 2 * sigmas[0]
817
+ w2 = sigmas[i+1]/w
818
+ w1 = 1 - w2
819
+
820
+ d_prime = d * w1 + d_2 * w2
821
+
822
+
823
+ x = x + d_prime * dt
824
+
825
+ else:
826
+ # Heun++
827
+ x_2 = x + d * dt
828
+ denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
829
+ d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
830
+ dt_2 = sigmas[i + 2] - sigmas[i + 1]
831
+
832
+ x_3 = x_2 + d_2 * dt_2
833
+ denoised_3 = model(x_3, sigmas[i + 2] * s_in, **extra_args)
834
+ d_3 = to_d(x_3, sigmas[i + 2], denoised_3)
835
+
836
+ w = 3 * sigmas[0]
837
+ w2 = sigmas[i + 1] / w
838
+ w3 = sigmas[i + 2] / w
839
+ w1 = 1 - w2 - w3
840
+
841
+ d_prime = w1 * d + w2 * d_2 + w3 * d_3
842
+ x = x + d_prime * dt
843
+ return x
844
+
845
+
846
+ #From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
847
+ #under Apache 2 license
848
+ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
849
+ extra_args = {} if extra_args is None else extra_args
850
+ s_in = x.new_ones([x.shape[0]])
851
+
852
+ x_next = x
853
+
854
+ buffer_model = []
855
+ for i in trange(len(sigmas) - 1, disable=disable):
856
+ t_cur = sigmas[i]
857
+ t_next = sigmas[i + 1]
858
+
859
+ x_cur = x_next
860
+
861
+ denoised = model(x_cur, t_cur * s_in, **extra_args)
862
+ if callback is not None:
863
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
864
+
865
+ d_cur = (x_cur - denoised) / t_cur
866
+
867
+ order = min(max_order, i+1)
868
+ if order == 1: # First Euler step.
869
+ x_next = x_cur + (t_next - t_cur) * d_cur
870
+ elif order == 2: # Use one history point.
871
+ x_next = x_cur + (t_next - t_cur) * (3 * d_cur - buffer_model[-1]) / 2
872
+ elif order == 3: # Use two history points.
873
+ x_next = x_cur + (t_next - t_cur) * (23 * d_cur - 16 * buffer_model[-1] + 5 * buffer_model[-2]) / 12
874
+ elif order == 4: # Use three history points.
875
+ x_next = x_cur + (t_next - t_cur) * (55 * d_cur - 59 * buffer_model[-1] + 37 * buffer_model[-2] - 9 * buffer_model[-3]) / 24
876
+
877
+ if len(buffer_model) == max_order - 1:
878
+ for k in range(max_order - 2):
879
+ buffer_model[k] = buffer_model[k+1]
880
+ buffer_model[-1] = d_cur
881
+ else:
882
+ buffer_model.append(d_cur)
883
+
884
+ return x_next
885
+
886
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license
@torch.no_grad()
def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
    """Variable-step iPNDM sampler (Adams-Bashforth with non-uniform steps).

    Unlike ``sample_ipndm``, the multistep coefficients are recomputed from the
    actual step sizes h_n, h_n-1, ... each iteration.  ``@torch.no_grad()``
    added for consistency with ``sample_deis`` below.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    x_next = x
    t_steps = sigmas

    buffer_model = []  # most recent derivatives, newest last
    for i in trange(len(sigmas) - 1, disable=disable):
        t_cur = sigmas[i]
        t_next = sigmas[i + 1]

        x_cur = x_next

        denoised = model(x_cur, t_cur * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})

        d_cur = (x_cur - denoised) / t_cur

        order = min(max_order, i+1)
        if order == 1: # First Euler step.
            x_next = x_cur + (t_next - t_cur) * d_cur
        elif order == 2: # Use one history point.
            h_n = (t_next - t_cur)
            h_n_1 = (t_cur - t_steps[i-1])
            coeff1 = (2 + (h_n / h_n_1)) / 2
            coeff2 = -(h_n / h_n_1) / 2
            x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1])
        elif order == 3: # Use two history points.
            h_n = (t_next - t_cur)
            h_n_1 = (t_cur - t_steps[i-1])
            h_n_2 = (t_steps[i-1] - t_steps[i-2])
            temp = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2
            coeff1 = (2 + (h_n / h_n_1)) / 2 + temp
            coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp
            coeff3 = temp * h_n_1 / h_n_2
            x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2])
        elif order == 4: # Use three history points.
            h_n = (t_next - t_cur)
            h_n_1 = (t_cur - t_steps[i-1])
            h_n_2 = (t_steps[i-1] - t_steps[i-2])
            h_n_3 = (t_steps[i-2] - t_steps[i-3])
            temp1 = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2
            temp2 = ((1 - h_n / (3 * (h_n + h_n_1))) / 2 + (1 - h_n / (2 * (h_n + h_n_1))) * h_n / (6 * (h_n + h_n_1 + h_n_2))) \
                   * (h_n * (h_n + h_n_1) * (h_n + h_n_1 + h_n_2)) / (h_n_1 * (h_n_1 + h_n_2) * (h_n_1 + h_n_2 + h_n_3))
            coeff1 = (2 + (h_n / h_n_1)) / 2 + temp1 + temp2
            coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp1 - (1 + (h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3)))) * temp2
            coeff3 = temp1 * h_n_1 / h_n_2 + ((h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) * (1 + h_n_2 / h_n_3)) * temp2
            coeff4 = -temp2 * (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) * h_n_1 / h_n_2
            x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2] + coeff4 * buffer_model[-3])

        # Slide the history window once it is full; detach to drop any graph.
        if len(buffer_model) == max_order - 1:
            for k in range(max_order - 2):
                buffer_model[k] = buffer_model[k+1]
            buffer_model[-1] = d_cur.detach()
        else:
            buffer_model.append(d_cur.detach())

    return x_next
949
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license
@torch.no_grad()
def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=3, deis_mode='tab'):
    """DEIS sampler (Diffusion Exponential Integrator Sampler).

    Multistep coefficients for each step are precomputed by
    ``deis.get_deis_coeff_list`` from the full sigma schedule; the effective
    order ramps up from 1 while the derivative history fills.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    x_next = x
    t_steps = sigmas

    # One coefficient tuple per step, of length == order used at that step.
    coeff_list = deis.get_deis_coeff_list(t_steps, max_order, deis_mode=deis_mode)

    buffer_model = []  # most recent derivatives, newest last
    for i in trange(len(sigmas) - 1, disable=disable):
        t_cur = sigmas[i]
        t_next = sigmas[i + 1]

        x_cur = x_next

        denoised = model(x_cur, t_cur * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})

        # Karras-style derivative d = (x - denoised) / sigma.
        d_cur = (x_cur - denoised) / t_cur

        order = min(max_order, i+1)
        # The final step to sigma == 0 falls back to a plain Euler step.
        if t_next <= 0:
            order = 1

        if order == 1: # First Euler step.
            x_next = x_cur + (t_next - t_cur) * d_cur
        elif order == 2: # Use one history point.
            coeff_cur, coeff_prev1 = coeff_list[i]
            x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1]
        elif order == 3: # Use two history points.
            coeff_cur, coeff_prev1, coeff_prev2 = coeff_list[i]
            x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2]
        elif order == 4: # Use three history points.
            coeff_cur, coeff_prev1, coeff_prev2, coeff_prev3 = coeff_list[i]
            x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2] + coeff_prev3 * buffer_model[-3]

        # Slide the history window once it is full; detach to drop any graph.
        if len(buffer_model) == max_order - 1:
            for k in range(max_order - 2):
                buffer_model[k] = buffer_model[k+1]
            buffer_model[-1] = d_cur.detach()
        else:
            buffer_model.append(d_cur.detach())

    return x_next
999
@torch.no_grad()
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """Euler sampler, CFG++ variant.

    A post-CFG hook captures the unconditional prediction into ``temp[0]``;
    the step direction ``d`` is computed from it, while the next ``x`` is
    rebuilt from the conditional ``denoised``.  The update
    ``x = denoised + d * sigmas[i + 1]`` is the CFG++ rule, intentionally
    different from the standard ``x = x + d * dt`` Euler step.
    """
    extra_args = {} if extra_args is None else extra_args

    temp = [0]
    def post_cfg_function(args):
        # Stash the uncond prediction for this step; pass denoised through.
        temp[0] = args["uncond_denoised"]
        return args["denoised"]

    model_options = extra_args.get("model_options", {}).copy()
    extra_args["model_options"] = totoro.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)

    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma_hat = sigmas[i]
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, temp[0])
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        # CFG++ Euler update (unused `dt` removed).
        x = denoised + d * sigmas[i + 1]
    return x
1023
@torch.no_grad()
def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with Euler method steps, CFG++ variant.

    A post-CFG hook captures the unconditional prediction into ``temp[0]``;
    the step direction ``d`` uses it, while the next ``x`` is rebuilt from the
    conditional ``denoised`` (``x = denoised + d * sigma_down`` — the CFG++
    rule, not the standard ``x + d * dt`` step).  Ancestral noise is added
    whenever the next sigma is nonzero.
    """
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler

    temp = [0]
    def post_cfg_function(args):
        # Stash the uncond prediction for this step; pass denoised through.
        temp[0] = args["uncond_denoised"]
        return args["denoised"]

    model_options = extra_args.get("model_options", {}).copy()
    extra_args["model_options"] = totoro.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)

    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        d = to_d(x, sigmas[i], temp[0])
        # CFG++ Euler update (unused `dt` removed).
        x = denoised + d * sigma_down
        if sigmas[i + 1] > 0:
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x
content/flux/totoro/k_diffusion/utils.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import contextmanager
2
+ import hashlib
3
+ import math
4
+ from pathlib import Path
5
+ import shutil
6
+ import urllib
7
+ import warnings
8
+
9
+ from PIL import Image
10
+ import torch
11
+ from torch import nn, optim
12
+ from torch.utils import data
13
+
14
+
15
def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):
    """Apply the given transform to each image under *image_key* of a
    HuggingFace Datasets batch, converting to *mode* first."""
    converted = (img.convert(mode) for img in examples[image_key])
    return {image_key: [transform(img) for img in converted]}
20
+
21
def append_dims(x, target_dims):
    """Append trailing singleton dimensions to *x* until it has *target_dims*
    dimensions; raises ValueError if *x* already has more."""
    missing = target_dims - x.ndim
    if missing < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    expanded = x[(...,) + (None,) * missing]
    # MPS will get inf values if it tries to index into the new axes, but detaching fixes this.
    # https://github.com/pytorch/pytorch/issues/84364
    if expanded.device.type == 'mps':
        return expanded.detach().clone()
    return expanded
31
+
32
def n_params(module):
    """Return the total number of elements across all parameters of *module*."""
    total = 0
    for param in module.parameters():
        total += param.numel()
    return total
36
+
37
def download_file(path, url, digest=None):
    """Download *url* to *path* if the file does not exist, optionally
    verifying its SHA-256 hash.

    Returns the path as a ``pathlib.Path``.  Raises OSError when *digest* is
    given and does not match the file's hash.
    """
    # A bare `import urllib` at module level does not guarantee the
    # `urllib.request` submodule is loaded; import it explicitly here.
    import urllib.request
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    if not path.exists():
        with urllib.request.urlopen(url) as response, open(path, 'wb') as f:
            shutil.copyfileobj(response, f)
    if digest is not None:
        # Close the file handle deterministically (was left to the GC before).
        with open(path, 'rb') as f:
            file_digest = hashlib.sha256(f.read()).hexdigest()
        if digest != file_digest:
            raise OSError(f'hash of {path} (url: {url}) failed to validate')
    return path
50
+
51
@contextmanager
def train_mode(model, mode=True):
    """Context manager: put *model* into the given training mode and restore
    each submodule's previous mode on exit."""
    saved = [m.training for m in model.modules()]
    try:
        yield model.train(mode)
    finally:
        for submodule, was_training in zip(model.modules(), saved):
            submodule.training = was_training
62
+
63
def eval_mode(model):
    """A context manager that places a model into evaluation mode and restores
    the previous mode on exit."""
    # Thin wrapper: delegates to train_mode with mode=False.
    return train_mode(model, False)
68
+
69
@torch.no_grad()
def ema_update(model, averaged_model, decay):
    """Fold the current parameters of *model* into *averaged_model* using the
    given EMA decay; buffers are copied verbatim.  Call after each optimizer
    step."""
    params = dict(model.named_parameters())
    avg_params = dict(averaged_model.named_parameters())
    assert params.keys() == avg_params.keys()
    for name in params:
        # avg <- decay * avg + (1 - decay) * param
        avg_params[name].mul_(decay).add_(params[name], alpha=1 - decay)

    buffers = dict(model.named_buffers())
    avg_buffers = dict(averaged_model.named_buffers())
    assert buffers.keys() == avg_buffers.keys()
    for name in buffers:
        avg_buffers[name].copy_(buffers[name])
87
+
88
class EMAWarmup:
    """Implements an EMA warmup using an inverse decay schedule.
    If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
    good values for models you plan to train for a million or more steps (reaches decay
    factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
    you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
    215.4k steps).
    Args:
        inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
        power (float): Exponential factor of EMA warmup. Default: 1.
        min_value (float): The minimum EMA decay rate. Default: 0.
        max_value (float): The maximum EMA decay rate. Default: 1.
        start_at (int): The epoch to start averaging at. Default: 0.
        last_epoch (int): The index of last epoch. Default: 0.
    """

    def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0,
                 last_epoch=0):
        self.inv_gamma = inv_gamma
        self.power = power
        self.min_value = min_value
        self.max_value = max_value
        self.start_at = start_at
        self.last_epoch = last_epoch

    def state_dict(self):
        """Returns the state of the class as a :class:`dict`."""
        return dict(self.__dict__.items())

    def load_state_dict(self, state_dict):
        """Loads the class's state.
        Args:
            state_dict (dict): scaler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_value(self):
        """Gets the current EMA decay rate, clamped to [min_value, max_value]."""
        epoch = max(0, self.last_epoch - self.start_at)
        value = 1 - (1 + epoch / self.inv_gamma) ** -self.power
        # `epoch` is clamped to >= 0 above, so the previous
        # `0. if epoch < 0 else ...` branch was unreachable and was removed.
        return min(self.max_value, max(self.min_value, value))

    def step(self):
        """Updates the step count."""
        self.last_epoch += 1
135
+
136
class InverseLR(optim.lr_scheduler._LRScheduler):
    """Implements an inverse decay learning rate schedule with an optional exponential
    warmup. When last_epoch=-1, sets initial lr as lr.
    inv_gamma is the number of steps/epochs required for the learning rate to decay to
    (1 / 2)**power of its original value.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1.
        power (float): Exponential factor of learning rate decay. Default: 1.
        warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
            Default: 0.
        min_lr (float): The minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
    """

    def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., min_lr=0.,
                 last_epoch=-1, verbose=False):
        self.inv_gamma = inv_gamma
        self.power = power
        if not 0. <= warmup < 1:
            raise ValueError('Invalid value for warmup')
        self.warmup = warmup
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        # Warn if called outside of a scheduler step, mirroring the torch
        # built-in schedulers' behavior.
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")

        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        # Warmup multiplier 1 - warmup**(epoch+1) approaches 1 over time.
        warmup = 1 - self.warmup ** (self.last_epoch + 1)
        lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power
        return [warmup * max(self.min_lr, base_lr * lr_mult)
                for base_lr in self.base_lrs]
176
+
177
class ExponentialLR(optim.lr_scheduler._LRScheduler):
    """Implements an exponential learning rate schedule with an optional exponential
    warmup. When last_epoch=-1, sets initial lr as lr. Decays the learning rate
    continuously by decay (default 0.5) every num_steps steps.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        num_steps (float): The number of steps to decay the learning rate by decay in.
        decay (float): The factor by which to decay the learning rate every num_steps
            steps. Default: 0.5.
        warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
            Default: 0.
        min_lr (float): The minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
    """

    def __init__(self, optimizer, num_steps, decay=0.5, warmup=0., min_lr=0.,
                 last_epoch=-1, verbose=False):
        self.num_steps = num_steps
        self.decay = decay
        if not 0. <= warmup < 1:
            raise ValueError('Invalid value for warmup')
        self.warmup = warmup
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        # Warn if called outside of a scheduler step, mirroring the torch
        # built-in schedulers' behavior.
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")

        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        # Warmup multiplier 1 - warmup**(epoch+1) approaches 1 over time.
        warmup = 1 - self.warmup ** (self.last_epoch + 1)
        # Continuous decay: multiplies by `decay` every `num_steps` epochs.
        lr_mult = (self.decay ** (1 / self.num_steps)) ** self.last_epoch
        return [warmup * max(self.min_lr, base_lr * lr_mult)
                for base_lr in self.base_lrs]
217
+
218
def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
    """Draws samples from a lognormal distribution."""
    normal = torch.randn(shape, device=device, dtype=dtype)
    return torch.exp(normal * scale + loc)
222
+
223
def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
    """Draws samples from an optionally truncated log-logistic distribution."""
    lo = torch.as_tensor(min_value, device=device, dtype=torch.float64)
    hi = torch.as_tensor(max_value, device=device, dtype=torch.float64)
    # CDF of the log-logistic at the truncation bounds (computed in float64).
    cdf_lo = torch.sigmoid((lo.log() - loc) / scale)
    cdf_hi = torch.sigmoid((hi.log() - loc) / scale)
    # Inverse-CDF sampling restricted to [cdf_lo, cdf_hi].
    u = torch.rand(shape, device=device, dtype=torch.float64) * (cdf_hi - cdf_lo) + cdf_lo
    return torch.exp(torch.logit(u) * scale + loc).to(dtype)
232
+
233
def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32):
    """Draws samples from a log-uniform distribution."""
    lo = math.log(min_value)
    hi = math.log(max_value)
    u = torch.rand(shape, device=device, dtype=dtype)
    return torch.exp(u * (hi - lo) + lo)
239
+
240
def rand_v_diffusion(shape, sigma_data=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
    """Draws samples from a truncated v-diffusion training timestep distribution."""
    # CDF of sigma under the v-diffusion prior is atan(sigma / sigma_data) * 2 / pi.
    cdf_lo = math.atan(min_value / sigma_data) * 2 / math.pi
    cdf_hi = math.atan(max_value / sigma_data) * 2 / math.pi
    u = torch.rand(shape, device=device, dtype=dtype) * (cdf_hi - cdf_lo) + cdf_lo
    return sigma_data * torch.tan(u * math.pi / 2)
247
+
248
def rand_split_log_normal(shape, loc, scale_1, scale_2, device='cpu', dtype=torch.float32):
    """Draws samples from a split lognormal distribution."""
    magnitude = torch.randn(shape, device=device, dtype=dtype).abs()
    side = torch.rand(shape, device=device, dtype=dtype)
    left = loc - magnitude * scale_1
    right = loc + magnitude * scale_2
    # Choose the left branch with probability scale_1 / (scale_1 + scale_2).
    p_left = scale_1 / (scale_1 + scale_2)
    return torch.where(side < p_left, left, right).exp()
257
+
258
class FolderOfImages(data.Dataset):
    """Recursively finds all images in a directory. It does not support
    classes/targets."""

    # File suffixes treated as images (matched case-insensitively).
    IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'}

    def __init__(self, root, transform=None):
        super().__init__()
        self.root = Path(root)
        # Identity transform by default so __getitem__ can call it unconditionally.
        self.transform = nn.Identity() if transform is None else transform
        # Sorted for a deterministic index -> file mapping.
        self.paths = sorted(path for path in self.root.rglob('*') if path.suffix.lower() in self.IMG_EXTENSIONS)

    def __repr__(self):
        return f'FolderOfImages(root="{self.root}", len: {len(self)})'

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, key):
        path = self.paths[key]
        with open(path, 'rb') as f:
            image = Image.open(f).convert('RGB')
        image = self.transform(image)
        # Returned as a 1-tuple to mimic (input, target)-style datasets.
        return image,
283
+
284
class CSVLogger:
    """Minimal CSV logger: appends to an existing file, otherwise creates it
    and writes the header row first.  Values are comma-joined with no quoting
    or escaping, so they must not contain commas themselves."""

    def __init__(self, filename, columns):
        self.filename = Path(filename)
        self.columns = columns
        already_exists = self.filename.exists()
        self.file = open(self.filename, 'a' if already_exists else 'w')
        if not already_exists:
            self.write(*self.columns)

    def write(self, *args):
        """Write one comma-separated row and flush immediately."""
        print(*args, sep=',', file=self.file, flush=True)
297
+
298
@contextmanager
def tf32_mode(cudnn=None, matmul=None):
    """Context manager that sets whether TF32 is allowed on cuDNN and/or
    matmul, restoring the previous settings on exit.  A ``None`` argument
    leaves the corresponding flag untouched."""
    prev_cudnn = torch.backends.cudnn.allow_tf32
    prev_matmul = torch.backends.cuda.matmul.allow_tf32
    try:
        if cudnn is not None:
            torch.backends.cudnn.allow_tf32 = cudnn
        if matmul is not None:
            torch.backends.cuda.matmul.allow_tf32 = matmul
        yield
    finally:
        if cudnn is not None:
            torch.backends.cudnn.allow_tf32 = prev_cudnn
        if matmul is not None:
            torch.backends.cuda.matmul.allow_tf32 = prev_matmul
content/flux/totoro/latent_formats.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
class LatentFormat:
    """Base class for latent-space formats.

    Subclasses override the class attributes and, when needed, the
    process_in/process_out pair, which map latents to and from the scaled
    space the diffusion model works in.
    """
    scale_factor = 1.0
    latent_channels = 4
    latent_rgb_factors = None
    taesd_decoder_name = None

    def process_in(self, latent):
        """Scale a latent for consumption by the diffusion model."""
        return self.scale_factor * latent

    def process_out(self, latent):
        """Undo process_in's scaling."""
        return latent / self.scale_factor
class SD15(LatentFormat):
    """SD 1.x latent format: scale 0.18215 plus RGB preview factors."""
    def __init__(self, scale_factor=0.18215):
        self.scale_factor = scale_factor
        # 4-channel latent -> RGB projection used for fast previews.
        self.latent_rgb_factors = [
            #   R        G        B
            [ 0.3512,  0.2297,  0.3227],
            [ 0.3250,  0.4974,  0.2350],
            [-0.2829,  0.1762,  0.2721],
            [-0.2120, -0.2616, -0.7177]
        ]
        self.taesd_decoder_name = "taesd_decoder"
class SDXL(LatentFormat):
    """SDXL latent format: scale 0.13025 plus RGB preview factors."""
    scale_factor = 0.13025

    def __init__(self):
        # 4-channel latent -> RGB projection used for fast previews.
        self.latent_rgb_factors = [
            #   R        G        B
            [ 0.3920,  0.4054,  0.4549],
            [-0.2634, -0.0196,  0.0653],
            [ 0.0568,  0.1687, -0.0755],
            [-0.3112, -0.2359, -0.2076]
        ]
        self.taesd_decoder_name = "taesdxl_decoder"
class SDXL_Playground_2_5(LatentFormat):
    """Playground v2.5 latent format: standardizes latents with per-channel
    mean/std in addition to the 0.5 scale factor."""
    def __init__(self):
        self.scale_factor = 0.5
        # Per-channel statistics, shaped for broadcasting over (B, 4, H, W).
        self.latents_mean = torch.tensor([-1.6574, 1.886, -1.383, 2.5155]).view(1, 4, 1, 1)
        self.latents_std = torch.tensor([8.4927, 5.9022, 6.5498, 5.2299]).view(1, 4, 1, 1)

        # 4-channel latent -> RGB projection used for fast previews.
        self.latent_rgb_factors = [
            #   R        G        B
            [ 0.3920,  0.4054,  0.4549],
            [-0.2634, -0.0196,  0.0653],
            [ 0.0568,  0.1687, -0.0755],
            [-0.3112, -0.2359, -0.2076]
        ]
        self.taesd_decoder_name = "taesdxl_decoder"

    def process_in(self, latent):
        # Center by the dataset mean, then scale and divide by the std.
        latents_mean = self.latents_mean.to(latent.device, latent.dtype)
        latents_std = self.latents_std.to(latent.device, latent.dtype)
        return (latent - latents_mean) * self.scale_factor / latents_std

    def process_out(self, latent):
        # Exact inverse of process_in.
        latents_mean = self.latents_mean.to(latent.device, latent.dtype)
        latents_std = self.latents_std.to(latent.device, latent.dtype)
        return latent * latents_std / self.scale_factor + latents_mean
class SD_X4(LatentFormat):
    """SD x4 upscaler latent format: scale 0.08333 plus RGB preview factors."""
    def __init__(self):
        self.scale_factor = 0.08333
        self.latent_rgb_factors = [
            [-0.2340, -0.3863, -0.3257],
            [ 0.0994,  0.0885, -0.0908],
            [-0.2833, -0.2349, -0.3741],
            [ 0.2523, -0.0055, -0.1651]
        ]
class SC_Prior(LatentFormat):
    """Stable Cascade prior (stage C) latent format: 16 channels, identity
    scaling, with a 16-channel -> RGB preview projection."""
    latent_channels = 16
    def __init__(self):
        self.scale_factor = 1.0
        self.latent_rgb_factors = [
            [-0.0326, -0.0204, -0.0127],
            [-0.1592, -0.0427,  0.0216],
            [ 0.0873,  0.0638, -0.0020],
            [-0.0602,  0.0442,  0.1304],
            [ 0.0800, -0.0313, -0.1796],
            [-0.0810, -0.0638, -0.1581],
            [ 0.1791,  0.1180,  0.0967],
            [ 0.0740,  0.1416,  0.0432],
            [-0.1745, -0.1888, -0.1373],
            [ 0.2412,  0.1577,  0.0928],
            [ 0.1908,  0.0998,  0.0682],
            [ 0.0209,  0.0365, -0.0092],
            [ 0.0448, -0.0650, -0.1728],
            [-0.1658, -0.1045, -0.1308],
            [ 0.0542,  0.1545,  0.1325],
            [-0.0352, -0.1672, -0.2541]
        ]
class SC_B(LatentFormat):
    """Stable Cascade stage B latent format: scale 1/0.43 plus RGB preview
    factors."""
    def __init__(self):
        self.scale_factor = 1.0 / 0.43
        self.latent_rgb_factors = [
            [ 0.1121,  0.2006,  0.1023],
            [-0.2093, -0.0222, -0.0195],
            [-0.3087, -0.1535,  0.0366],
            [ 0.0290, -0.1574, -0.4078]
        ]
class SD3(LatentFormat):
    """SD3 latent format: 16 channels with an affine (shift, then scale)
    transform instead of the base class's pure scaling."""
    latent_channels = 16
    def __init__(self):
        self.scale_factor = 1.5305
        self.shift_factor = 0.0609
        # 16-channel latent -> RGB projection used for fast previews.
        self.latent_rgb_factors = [
            [-0.0645,  0.0177,  0.1052],
            [ 0.0028,  0.0312,  0.0650],
            [ 0.1848,  0.0762,  0.0360],
            [ 0.0944,  0.0360,  0.0889],
            [ 0.0897,  0.0506, -0.0364],
            [-0.0020,  0.1203,  0.0284],
            [ 0.0855,  0.0118,  0.0283],
            [-0.0539,  0.0658,  0.1047],
            [-0.0057,  0.0116,  0.0700],
            [-0.0412,  0.0281, -0.0039],
            [ 0.1106,  0.1171,  0.1220],
            [-0.0248,  0.0682, -0.0481],
            [ 0.0815,  0.0846,  0.1207],
            [-0.0120, -0.0055, -0.0867],
            [-0.0749, -0.0634, -0.0456],
            [-0.1418, -0.1457, -0.1259]
        ]
        self.taesd_decoder_name = "taesd3_decoder"

    def process_in(self, latent):
        # Subtract the shift before scaling (affine, not purely multiplicative).
        return (latent - self.shift_factor) * self.scale_factor

    def process_out(self, latent):
        # Exact inverse of process_in.
        return (latent / self.scale_factor) + self.shift_factor
class StableAudio1(LatentFormat):
    """Stable Audio 1.0 latent format: 64 channels with the base class's
    default (identity) scaling."""
    latent_channels = 64
class Flux(SD3):
    """Flux latent format: same affine (shift, then scale) transform as SD3
    with Flux's own constants.

    __init__ deliberately does not call SD3.__init__, so latent_rgb_factors
    and taesd_decoder_name keep the LatentFormat class defaults (None).  The
    process_in/out overrides are textually identical to SD3's and are kept
    for explicitness.
    """
    def __init__(self):
        self.scale_factor = 0.3611
        self.shift_factor = 0.1159

    def process_in(self, latent):
        return (latent - self.shift_factor) * self.scale_factor

    def process_out(self, latent):
        return (latent / self.scale_factor) + self.shift_factor
content/flux/totoro/ldm/audio/autoencoder.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # code adapted from: https://github.com/Stability-AI/stable-audio-tools
2
+
3
+ import torch
4
+ from torch import nn
5
+ from typing import Literal, Dict, Any
6
+ import math
7
+ import totoro.ops
8
+ ops = totoro.ops.disable_weight_init
9
+
10
def vae_sample(mean, scale):
    """Reparameterized VAE sample.

    *scale* is mapped through softplus (with a 1e-4 floor) to a positive
    stdev.  Returns (latents, kl) where kl is the KL divergence to the
    standard normal prior, summed over dim 1 and averaged over the batch.
    """
    stdev = nn.functional.softplus(scale) + 1e-4
    var = stdev * stdev
    logvar = torch.log(var)
    noise = torch.randn_like(mean)
    latents = noise * stdev + mean

    kl = (mean * mean + var - logvar - 1).sum(1).mean()

    return latents, kl
20
class VAEBottleneck(nn.Module):
    """VAE bottleneck: encode splits the channel dim into (mean, scale) halves
    and samples via vae_sample; decode is the identity."""
    def __init__(self):
        super().__init__()
        # Continuous (Gaussian) bottleneck, as opposed to a quantized one.
        self.is_discrete = False

    def encode(self, x, return_info=False, **kwargs):
        info = {}

        # First half of dim 1 is the mean, second half the raw (pre-softplus) scale.
        mean, scale = x.chunk(2, dim=1)

        x, kl = vae_sample(mean, scale)

        info["kl"] = kl

        if return_info:
            return x, info
        else:
            return x

    def decode(self, x):
        # Identity: latents pass through unchanged.
        return x
def snake_beta(x, alpha, beta):
    """Snake activation: x + sin^2(alpha * x) / beta, with an epsilon guard
    on the division."""
    eps = 0.000000001
    return x + (1.0 / (beta + eps)) * pow(torch.sin(x * alpha), 2)
# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
class SnakeBeta(nn.Module):
    """Snake activation with learnable per-channel alpha (frequency) and beta
    (magnitude), optionally parameterized in log scale."""

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale: # log scale alphas initialized to zeros
            self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
            self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
        else: # linear scale alphas initialized to ones
            self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
            self.beta = nn.Parameter(torch.ones(in_features) * alpha)

        # NOTE(review): alpha_trainable is currently ignored — the lines that
        # would honor it are commented out below; confirm whether freezing is
        # ever needed.
        # self.alpha.requires_grad = alpha_trainable
        # self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1).to(x.device) # line up with x to [B, C, T]
        beta = self.beta.unsqueeze(0).unsqueeze(-1).to(x.device)
        if self.alpha_logscale:
            # Log-scale parameters are exponentiated to positive values.
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = snake_beta(x, alpha, beta)

        return x
def WNConv1d(*args, **kwargs):
    """Weight-normalized Conv1d.

    Uses the parametrizations API when available, falling back to the
    deprecated torch.nn.utils.weight_norm on torch 2.1 and older.  The bare
    ``except:`` was narrowed to AttributeError (missing API attribute), and
    the conv is now constructed only once.
    """
    conv = ops.Conv1d(*args, **kwargs)
    try:
        return torch.nn.utils.parametrizations.weight_norm(conv)
    except AttributeError:
        return torch.nn.utils.weight_norm(conv) #support pytorch 2.1 and older
def WNConvTranspose1d(*args, **kwargs):
    """Weight-normalized ConvTranspose1d.

    Uses the parametrizations API when available, falling back to the
    deprecated torch.nn.utils.weight_norm on torch 2.1 and older.  The bare
    ``except:`` was narrowed to AttributeError (missing API attribute), and
    the conv is now constructed only once.
    """
    conv = ops.ConvTranspose1d(*args, **kwargs)
    try:
        return torch.nn.utils.parametrizations.weight_norm(conv)
    except AttributeError:
        return torch.nn.utils.weight_norm(conv) #support pytorch 2.1 and older
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
    """Build an activation module by name.

    *channels* is only used by the "snake" activation (per-channel params).
    """
    if activation == "elu":
        act = torch.nn.ELU()
    elif activation == "snake":
        act = SnakeBeta(channels)
    elif activation == "none":
        act = torch.nn.Identity()
    else:
        raise ValueError(f"Unknown activation {activation}")

    if antialias:
        # NOTE(review): Activation1d is not defined in this module, so
        # antialias=True raises NameError here — confirm the intended import
        # in the full repository.
        act = Activation1d(act)

    return act
class ResidualUnit(nn.Module):
    """Dilated residual block: act -> dilated 7-tap WN conv -> act -> 1x1 WN
    conv, with an identity skip connection."""
    def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
        super().__init__()

        self.dilation = dilation

        # "Same" padding for a kernel of size 7 at this dilation.
        padding = (dilation * (7-1)) // 2

        self.layers = nn.Sequential(
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
            WNConv1d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=7, dilation=dilation, padding=padding),
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
            WNConv1d(in_channels=out_channels, out_channels=out_channels,
                     kernel_size=1)
        )

    def forward(self, x):
        res = x

        #x = checkpoint(self.layers, x)
        x = self.layers(x)

        return x + res
class EncoderBlock(nn.Module):
    """Encoder stage: three residual units at dilations 1/3/9, then an
    activation and a strided WN conv that downsamples by *stride* while
    changing the channel count."""
    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
        super().__init__()

        self.layers = nn.Sequential(
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=1, use_snake=use_snake),
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=3, use_snake=use_snake),
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=9, use_snake=use_snake),
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
            WNConv1d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
        )

    def forward(self, x):
        return self.layers(x)
class DecoderBlock(nn.Module):
    """Decoder stage: activation, an upsampling layer (transposed WN conv, or
    nearest-neighbor upsample followed by a WN conv), then three residual
    units at dilations 1/3/9."""
    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
        super().__init__()

        if use_nearest_upsample:
            # Nearest-neighbor upsample + bias-free "same" conv avoids the
            # checkerboard artifacts of transposed convolutions.
            upsample_layer = nn.Sequential(
                nn.Upsample(scale_factor=stride, mode="nearest"),
                WNConv1d(in_channels=in_channels,
                         out_channels=out_channels,
                         kernel_size=2*stride,
                         stride=1,
                         bias=False,
                         padding='same')
            )
        else:
            upsample_layer = WNConvTranspose1d(in_channels=in_channels,
                                               out_channels=out_channels,
                                               kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))

        self.layers = nn.Sequential(
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
            upsample_layer,
            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                         dilation=1, use_snake=use_snake),
            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                         dilation=3, use_snake=use_snake),
            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                         dilation=9, use_snake=use_snake),
        )

    def forward(self, x):
        return self.layers(x)
class OobleckEncoder(nn.Module):
    """Strided convolutional audio encoder producing a `latent_dim`-channel latent sequence."""

    def __init__(self,
                 in_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults=[1, 2, 4, 8],
                 strides=[2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False):
        super().__init__()

        # Prepend a unit multiplier for the stem conv's output width.
        c_mults = [1] + c_mults
        self.depth = len(c_mults)

        layers = [
            WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels,
                     kernel_size=7, padding=3)
        ]
        layers.extend(
            EncoderBlock(in_channels=c_mults[i] * channels,
                         out_channels=c_mults[i + 1] * channels,
                         stride=strides[i],
                         use_snake=use_snake)
            for i in range(self.depth - 1)
        )
        layers.append(get_activation("snake" if use_snake else "elu",
                                     antialias=antialias_activation,
                                     channels=c_mults[-1] * channels))
        layers.append(WNConv1d(in_channels=c_mults[-1] * channels,
                               out_channels=latent_dim, kernel_size=3, padding=1))

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
214
+
215
+
216
class OobleckDecoder(nn.Module):
    """Mirror of OobleckEncoder: upsampling decoder mapping latents back to audio channels."""

    def __init__(self,
                 out_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults=[1, 2, 4, 8],
                 strides=[2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False,
                 use_nearest_upsample=False,
                 final_tanh=True):
        super().__init__()

        c_mults = [1] + c_mults
        self.depth = len(c_mults)

        layers = [
            WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1] * channels,
                     kernel_size=7, padding=3),
        ]
        # Walk the multipliers from deepest back to shallowest.
        layers.extend(
            DecoderBlock(
                in_channels=c_mults[i] * channels,
                out_channels=c_mults[i - 1] * channels,
                stride=strides[i - 1],
                use_snake=use_snake,
                antialias_activation=antialias_activation,
                use_nearest_upsample=use_nearest_upsample,
            )
            for i in range(self.depth - 1, 0, -1)
        )
        layers += [
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
            WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
            # Optionally squash the waveform into [-1, 1].
            nn.Tanh() if final_tanh else nn.Identity(),
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
258
+
259
+
260
class AudioOobleckVAE(nn.Module):
    """Oobleck audio VAE: the encoder emits 2*latent_dim channels (split by the
    bottleneck into mean/scale); the decoder maps sampled latents back to audio."""

    def __init__(self,
                 in_channels=2,
                 channels=128,
                 latent_dim=64,
                 c_mults=[1, 2, 4, 8, 16],
                 strides=[2, 4, 4, 8, 8],
                 use_snake=True,
                 antialias_activation=False,
                 use_nearest_upsample=False,
                 final_tanh=False):
        super().__init__()
        # Encoder width is doubled so the bottleneck can split it into mean/scale.
        self.encoder = OobleckEncoder(in_channels, channels, latent_dim * 2, c_mults,
                                      strides, use_snake, antialias_activation)
        # NOTE(review): decoder's out_channels is taken from in_channels — this
        # assumes input and output channel counts match.
        self.decoder = OobleckDecoder(in_channels, channels, latent_dim, c_mults,
                                      strides, use_snake, antialias_activation,
                                      use_nearest_upsample=use_nearest_upsample,
                                      final_tanh=final_tanh)
        self.bottleneck = VAEBottleneck()

    def encode(self, x):
        """Encode audio to sampled latents via the VAE bottleneck."""
        return self.bottleneck.encode(self.encoder(x))

    def decode(self, x):
        """Decode latents back to audio."""
        return self.decoder(self.bottleneck.decode(x))
282
+
content/flux/totoro/ldm/audio/dit.py ADDED
@@ -0,0 +1,891 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # code adapted from: https://github.com/Stability-AI/stable-audio-tools
2
+
3
+ from totoro.ldm.modules.attention import optimized_attention
4
+ import typing as tp
5
+
6
+ import torch
7
+
8
+ from einops import rearrange
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+ import math
12
+ import totoro.ops
13
+
14
class FourierFeatures(nn.Module):
    """Fourier-feature embedding for continuous inputs (used for timesteps).

    Projects the input by a learned frequency matrix and returns [cos, sin] features.
    """

    def __init__(self, in_features, out_features, std=1., dtype=None, device=None):
        super().__init__()
        assert out_features % 2 == 0
        # Half the outputs are cos features, half sin; weights come from the checkpoint.
        self.weight = nn.Parameter(torch.empty(
            [out_features // 2, in_features], dtype=dtype, device=device))

    def forward(self, input):
        # Cast the frequency matrix to the input's dtype/device before projecting.
        f = 2 * math.pi * input @ totoro.ops.cast_to_input(self.weight.T, input)
        return torch.cat([f.cos(), f.sin()], dim=-1)
24
+
25
+ # norms
26
class LayerNorm(nn.Module):
    """Layer norm with a learnable gain and optional bias.

    bias-less layernorm has been shown to be more stable; most newer models have
    moved towards (also bias-less) rmsnorm.
    """

    def __init__(self, dim, bias=False, fix_scale=False, dtype=None, device=None):
        super().__init__()
        self.gamma = nn.Parameter(torch.empty(dim, dtype=dtype, device=device))
        self.beta = nn.Parameter(torch.empty(dim, dtype=dtype, device=device)) if bias else None

    def forward(self, x):
        # Cast parameters to the input's dtype/device on the fly.
        beta = None if self.beta is None else totoro.ops.cast_to_input(self.beta, x)
        gamma = totoro.ops.cast_to_input(self.gamma, x)
        return F.layer_norm(x, x.shape[-1:], weight=gamma, bias=beta)
45
+
46
class GLU(nn.Module):
    """Gated linear unit: project to 2x width, gate one half with `activation`."""

    def __init__(
        self,
        dim_in,
        dim_out,
        activation,
        use_conv = False,
        conv_kernel_size = 3,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.act = activation
        self.use_conv = use_conv
        if use_conv:
            self.proj = operations.Conv1d(dim_in, dim_out * 2, conv_kernel_size,
                                          padding=(conv_kernel_size // 2),
                                          dtype=dtype, device=device)
        else:
            self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)

    def forward(self, x):
        if self.use_conv:
            # Conv1d wants (batch, channels, length); swap back afterwards.
            x = self.proj(x.transpose(1, 2)).transpose(1, 2)
        else:
            x = self.proj(x)

        h, gate = x.chunk(2, dim=-1)
        return h * self.act(gate)
73
+
74
class AbsolutePositionalEmbedding(nn.Module):
    """Learned absolute positional embedding, scaled by dim ** -0.5."""

    def __init__(self, dim, max_seq_len):
        super().__init__()
        self.scale = dim ** -0.5
        self.max_seq_len = max_seq_len
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x, pos = None, seq_start_pos = None):
        seq_len = x.shape[1]
        assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'

        if pos is None:
            pos = torch.arange(seq_len, device = x.device)

        if seq_start_pos is not None:
            # Shift positions so each sequence starts at 0; clamp padding to 0.
            pos = (pos - seq_start_pos[..., None]).clamp(min = 0)

        return self.emb(pos) * self.scale
94
+
95
class ScaledSinusoidalEmbedding(nn.Module):
    """Fixed sinusoidal positional embedding with a single learned scale."""

    def __init__(self, dim, theta = 10000):
        super().__init__()
        assert (dim % 2) == 0, 'dimension must be divisible by 2'
        self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)

        half_dim = dim // 2
        inv_freq = theta ** -(torch.arange(half_dim).float() / half_dim)
        self.register_buffer('inv_freq', inv_freq, persistent = False)

    def forward(self, x, pos = None, seq_start_pos = None):
        if pos is None:
            pos = torch.arange(x.shape[1], device = x.device)

        if seq_start_pos is not None:
            # Re-base positions per sequence.
            pos = pos - seq_start_pos[..., None]

        # Outer product positions x frequencies, then interleave sin/cos halves.
        emb = torch.einsum('i, j -> i j', pos, self.inv_freq)
        emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
        return emb * self.scale
118
+
119
class RotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE), with optional xpos length scaling.

    `base_rescale_factor` rescales rotary embeddings to longer sequence lengths
    without fine-tuning (NTK-aware scaling, proposed by reddit user bloc97):
    https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/

    forward(t) takes a 1D tensor of positions and returns (freqs, scale), where
    `scale` is 1. unless xpos is enabled.
    """

    def __init__(
        self,
        dim,
        use_xpos = False,
        scale_base = 512,
        interpolation_factor = 1.,
        base = 10000,
        base_rescale_factor = 1.,
        dtype=None,
        device=None,
    ):
        super().__init__()
        base *= base_rescale_factor ** (dim / (dim - 2))

        # inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        # Left empty: the actual values are loaded from the checkpoint.
        self.register_buffer('inv_freq', torch.empty((dim // 2,), device=device, dtype=dtype))

        assert interpolation_factor >= 1.
        self.interpolation_factor = interpolation_factor

        if not use_xpos:
            self.register_buffer('scale', None)
            return

        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

        self.scale_base = scale_base
        self.register_buffer('scale', scale)

    def forward_from_seq_len(self, seq_len, device, dtype):
        """Convenience wrapper: build positions 0..seq_len-1 and call forward."""
        t = torch.arange(seq_len, device=device, dtype=dtype)
        return self.forward(t)

    def forward(self, t):
        device = t.device
        # BUGFIX: seq_len was previously undefined in this method, raising a
        # NameError whenever use_xpos=True (self.scale is not None). It is the
        # length of the 1D position vector t.
        seq_len = t.shape[-1]

        # Positional interpolation for longer contexts.
        t = t / self.interpolation_factor

        freqs = torch.einsum('i , j -> i j', t, totoro.ops.cast_to_input(self.inv_freq, t))
        freqs = torch.cat((freqs, freqs), dim = -1)

        if self.scale is None:
            return freqs, 1.

        # xpos: decay scale symmetrically around the sequence midpoint.
        power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
        scale = totoro.ops.cast_to_input(self.scale, t) ** power[:, None]
        scale = torch.cat((scale, scale), dim = -1)

        return freqs, scale
178
+
179
def rotate_half(x):
    """Split the last dim into two halves (a, b) and return (-b, a) — the RoPE rotation."""
    x1, x2 = x.unflatten(-1, (2, -1)).unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)
183
+
184
def apply_rotary_pos_emb(t, freqs, scale = 1):
    """Apply rotary embeddings to the leading rot_dim channels of t (partial RoPE, Wang et al. GPT-J style)."""
    out_dtype = t.dtype

    # cast to float32 if necessary for numerical stability
    dtype = t.dtype
    rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
    freqs = freqs.to(dtype)
    t = t.to(dtype)
    # Align freqs with the trailing positions of t.
    freqs = freqs[-seq_len:, :]

    if t.ndim == 4 and freqs.ndim == 3:
        # Broadcast over the heads dim: b n d -> b 1 n d.
        freqs = freqs.unsqueeze(1)

    # Only the first rot_dim channels are rotated; the rest pass through.
    t_rot, t_pass = t[..., :rot_dim], t[..., rot_dim:]
    t_rot = (t_rot * freqs.cos() * scale) + (rotate_half(t_rot) * freqs.sin() * scale)

    return torch.cat((t_rot.to(out_dtype), t_pass.to(out_dtype)), dim = -1)
203
+
204
class FeedForward(nn.Module):
    """Transformer feed-forward block; defaults to SwiGLU (GLU with a SiLU gate)."""

    def __init__(
        self,
        dim,
        dim_out = None,
        mult = 4,
        no_bias = False,
        glu = True,
        use_conv = False,
        conv_kernel_size = 3,
        zero_init_output = True,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim if dim_out is None else dim_out

        # Default to SwiGLU
        activation = nn.SiLU()

        if glu:
            linear_in = GLU(dim, inner_dim, activation, dtype=dtype, device=device, operations=operations)
        else:
            # NOTE(review): `Rearrange` (einops.layers.torch) is not imported in
            # this module, so this branch looks like it would raise NameError if
            # ever taken with use_conv=True — confirm before relying on it.
            linear_in = nn.Sequential(
                Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
                operations.Linear(dim, inner_dim, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device),
                Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
                activation
            )

        if use_conv:
            linear_out = operations.Conv1d(inner_dim, dim_out, conv_kernel_size,
                                           padding=(conv_kernel_size // 2),
                                           bias=not no_bias, dtype=dtype, device=device)
        else:
            linear_out = operations.Linear(inner_dim, dim_out, bias=not no_bias, dtype=dtype, device=device)

        # (Upstream zero-init of linear_out is disabled; weights are loaded from
        # the checkpoint instead.)

        self.ff = nn.Sequential(
            linear_in,
            Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
            linear_out,
            Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
        )

    def forward(self, x):
        return self.ff(x)
256
+
257
class Attention(nn.Module):
    """Multi-head self- or cross-attention with optional grouped kv heads,
    cosine-sim QK normalization and rotary position embeddings.

    If dim_context is given, separate q and kv projections are created
    (cross-attention); otherwise a fused qkv projection is used.
    """
    def __init__(
        self,
        dim,
        dim_heads = 64,
        dim_context = None,
        causal = False,
        zero_init_output=True,
        qk_norm = False,
        natten_kernel_size = None,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.dim = dim
        self.dim_heads = dim_heads
        self.causal = causal

        # Keys/values may come from a context stream of a different width.
        dim_kv = dim_context if dim_context is not None else dim

        self.num_heads = dim // dim_heads
        self.kv_heads = dim_kv // dim_heads

        if dim_context is not None:
            self.to_q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
            self.to_kv = operations.Linear(dim_kv, dim_kv * 2, bias=False, dtype=dtype, device=device)
        else:
            self.to_qkv = operations.Linear(dim, dim * 3, bias=False, dtype=dtype, device=device)

        self.to_out = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)

        # if zero_init_output:
        #     nn.init.zeros_(self.to_out.weight)

        self.qk_norm = qk_norm


    def forward(
        self,
        x,
        context = None,
        mask = None,
        context_mask = None,
        rotary_pos_emb = None,
        causal = None
    ):
        h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None

        kv_input = context if has_context else x

        if hasattr(self, 'to_q'):
            # Use separate linear projections for q and k/v
            q = self.to_q(x)
            q = rearrange(q, 'b n (h d) -> b h n d', h = h)

            k, v = self.to_kv(kv_input).chunk(2, dim=-1)

            k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
        else:
            # Use fused linear projection
            q, k, v = self.to_qkv(x).chunk(3, dim=-1)
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # Normalize q and k for cosine sim attention
        if self.qk_norm:
            q = F.normalize(q, dim=-1)
            k = F.normalize(k, dim=-1)

        if rotary_pos_emb is not None and not has_context:
            # Rotary embeddings only apply to self-attention; do the rotation in
            # float32 and cast back afterwards.
            freqs, _ = rotary_pos_emb

            q_dtype = q.dtype
            k_dtype = k.dtype

            q = q.to(torch.float32)
            k = k.to(torch.float32)
            freqs = freqs.to(torch.float32)

            q = apply_rotary_pos_emb(q, freqs)
            k = apply_rotary_pos_emb(k, freqs)

            q = q.to(q_dtype)
            k = k.to(k_dtype)

        input_mask = context_mask

        if input_mask is None and not has_context:
            input_mask = mask

        # determine masking
        masks = []
        final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account

        if input_mask is not None:
            input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
            masks.append(~input_mask)

        # Other masks will be added here later

        # NOTE(review): `or_reduce` is not defined in this module, so this branch
        # would raise NameError if a mask is supplied — and `final_attn_mask`
        # (like `causal` below) is never passed to optimized_attention anyway.
        # Confirm whether masked/causal attention is expected here.
        if len(masks) > 0:
            final_attn_mask = ~or_reduce(masks)

        n, device = q.shape[-2], q.device

        causal = self.causal if causal is None else causal

        if n == 1 and causal:
            causal = False

        if h != kv_h:
            # Repeat interleave kv_heads to match q_heads
            heads_per_kv_head = h // kv_h
            k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))

        out = optimized_attention(q, k, v, h, skip_reshape=True)
        out = self.to_out(out)

        if mask is not None:
            # Zero the outputs at padded query positions.
            mask = rearrange(mask, 'b n -> b n 1')
            out = out.masked_fill(~mask, 0.)

        return out
380
+
381
class ConformerModule(nn.Module):
    """Conformer convolution module: norm -> pointwise conv -> GLU -> depthwise
    conv -> norm -> SiLU -> pointwise conv, operating on (batch, seq, dim)."""

    def __init__(
        self,
        dim,
        norm_kwargs = {},
    ):
        super().__init__()

        self.dim = dim

        self.in_norm = LayerNorm(dim, **norm_kwargs)
        self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
        # NOTE(review): GLU is constructed without `operations`, which defaults to
        # None — looks like it would fail if this module is instantiated; confirm.
        self.glu = GLU(dim, dim, nn.SiLU())
        self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
        self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm
        self.swish = nn.SiLU()
        self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)

    def forward(self, x):
        x = self.in_norm(x)
        # Conv1d layers want (batch, dim, seq); transpose around each conv.
        x = self.pointwise_conv(x.transpose(1, 2)).transpose(1, 2)
        x = self.glu(x)
        x = self.depthwise_conv(x.transpose(1, 2)).transpose(1, 2)
        x = self.mid_norm(x)
        x = self.swish(x)
        x = self.pointwise_conv_2(x.transpose(1, 2)).transpose(1, 2)

        return x
416
+
417
class TransformerBlock(nn.Module):
    """One transformer layer: self-attention, optional cross-attention, optional
    conformer conv module, and a feed-forward block.

    When global_cond_dim is set, the layer uses adaLN-style conditioning: a
    zero-initialized projection of `global_cond` produces scale/shift/gate
    triples for the self-attention and feed-forward branches.
    """
    def __init__(
            self,
            dim,
            dim_heads = 64,
            cross_attend = False,
            dim_context = None,
            global_cond_dim = None,
            causal = False,
            zero_init_branch_outputs = True,
            conformer = False,
            layer_ix = -1,
            remove_norms = False,
            attn_kwargs = {},
            ff_kwargs = {},
            norm_kwargs = {},
            dtype=None,
            device=None,
            operations=None,
    ):

        super().__init__()
        self.dim = dim
        self.dim_heads = dim_heads
        self.cross_attend = cross_attend
        self.dim_context = dim_context
        self.causal = causal

        self.pre_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()

        self.self_attn = Attention(
            dim,
            dim_heads = dim_heads,
            causal = causal,
            zero_init_output=zero_init_branch_outputs,
            dtype=dtype,
            device=device,
            operations=operations,
            **attn_kwargs
        )

        if cross_attend:
            self.cross_attend_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()
            self.cross_attn = Attention(
                dim,
                dim_heads = dim_heads,
                dim_context=dim_context,
                causal = causal,
                zero_init_output=zero_init_branch_outputs,
                dtype=dtype,
                device=device,
                operations=operations,
                **attn_kwargs
            )

        self.ff_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()
        self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, dtype=dtype, device=device, operations=operations,**ff_kwargs)

        # Index of this layer in the stack (informational).
        self.layer_ix = layer_ix

        self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None

        self.global_cond_dim = global_cond_dim

        if global_cond_dim is not None:
            # adaLN conditioning: produces (scale, shift, gate) for both the
            # self-attention and feed-forward branches; zero-initialized so the
            # layer starts as an unconditioned block.
            self.to_scale_shift_gate = nn.Sequential(
                nn.SiLU(),
                nn.Linear(global_cond_dim, dim * 6, bias=False)
            )

            nn.init.zeros_(self.to_scale_shift_gate[1].weight)
            #nn.init.zeros_(self.to_scale_shift_gate_self[1].bias)

    def forward(
        self,
        x,
        context = None,
        global_cond=None,
        mask = None,
        context_mask = None,
        rotary_pos_emb = None
    ):
        if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:

            # Split the conditioning projection into per-branch scale/shift/gate.
            scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim = -1)

            # self-attention with adaLN
            residual = x
            x = self.pre_norm(x)
            x = x * (1 + scale_self) + shift_self
            x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
            # Gate the branch output; sigmoid(1 - gate) ≈ 0.73 at zero init.
            x = x * torch.sigmoid(1 - gate_self)
            x = x + residual

            if context is not None:
                x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)

            if self.conformer is not None:
                x = x + self.conformer(x)

            # feedforward with adaLN
            residual = x
            x = self.ff_norm(x)
            x = x * (1 + scale_ff) + shift_ff
            x = self.ff(x)
            x = x * torch.sigmoid(1 - gate_ff)
            x = x + residual

        else:
            # Plain pre-norm residual layout without conditioning.
            x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)

            if context is not None:
                x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)

            if self.conformer is not None:
                x = x + self.conformer(x)

            x = x + self.ff(self.ff_norm(x))

        return x
537
+
538
class ContinuousTransformer(nn.Module):
    """Stack of TransformerBlocks over continuous-valued tokens.

    Supports optional input/output projections, rotary or sinusoidal/absolute
    positional embeddings, prepended conditioning tokens, and adaLN global
    conditioning passed through to each block.
    """
    def __init__(
        self,
        dim,
        depth,
        *,
        dim_in = None,
        dim_out = None,
        dim_heads = 64,
        cross_attend=False,
        cond_token_dim=None,
        global_cond_dim=None,
        causal=False,
        rotary_pos_emb=True,
        zero_init_branch_outputs=True,
        conformer=False,
        use_sinusoidal_emb=False,
        use_abs_pos_emb=False,
        abs_pos_emb_max_length=10000,
        dtype=None,
        device=None,
        operations=None,
        **kwargs
    ):

        super().__init__()

        self.dim = dim
        self.depth = depth
        self.causal = causal
        self.layers = nn.ModuleList([])

        # Optional width-changing projections at the boundaries.
        self.project_in = operations.Linear(dim_in, dim, bias=False, dtype=dtype, device=device) if dim_in is not None else nn.Identity()
        self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device) if dim_out is not None else nn.Identity()

        if rotary_pos_emb:
            # Rotate at most half the head dim, but at least 32 channels.
            self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32), device=device, dtype=dtype)
        else:
            self.rotary_pos_emb = None

        self.use_sinusoidal_emb = use_sinusoidal_emb
        if use_sinusoidal_emb:
            self.pos_emb = ScaledSinusoidalEmbedding(dim)

        self.use_abs_pos_emb = use_abs_pos_emb
        if use_abs_pos_emb:
            self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length)

        for i in range(depth):
            self.layers.append(
                TransformerBlock(
                    dim,
                    dim_heads = dim_heads,
                    cross_attend = cross_attend,
                    dim_context = cond_token_dim,
                    global_cond_dim = global_cond_dim,
                    causal = causal,
                    zero_init_branch_outputs = zero_init_branch_outputs,
                    conformer=conformer,
                    layer_ix=i,
                    dtype=dtype,
                    device=device,
                    operations=operations,
                    **kwargs
                )
            )

    def forward(
        self,
        x,
        mask = None,
        prepend_embeds = None,
        prepend_mask = None,
        global_cond = None,
        return_info = False,
        **kwargs
    ):
        batch, seq, device = *x.shape[:2], x.device

        info = {
            "hidden_states": [],
        }

        x = self.project_in(x)

        if prepend_embeds is not None:
            # Conditioning tokens are concatenated in front of the sequence.
            prepend_length, prepend_dim = prepend_embeds.shape[1:]

            assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'

            x = torch.cat((prepend_embeds, x), dim = -2)

            if prepend_mask is not None or mask is not None:
                # Default either mask to all-ones so the concatenation lines up.
                mask = mask if mask is not None else torch.ones((batch, seq), device = device, dtype = torch.bool)
                prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length), device = device, dtype = torch.bool)

                mask = torch.cat((prepend_mask, mask), dim = -1)

        # Attention layers

        if self.rotary_pos_emb is not None:
            rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=x.dtype, device=x.device)
        else:
            rotary_pos_emb = None

        if self.use_sinusoidal_emb or self.use_abs_pos_emb:
            x = x + self.pos_emb(x)

        # Iterate over the transformer layers
        for layer in self.layers:
            x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
            # x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)

            if return_info:
                info["hidden_states"].append(x)

        x = self.project_out(x)

        if return_info:
            return x, info

        return x
660
+
661
class AudioDiffusionTransformer(nn.Module):
    """Diffusion transformer over audio latents (adapted from stable-audio-tools).

    Embeds the timestep with Fourier features, projects cross-attention /
    global / prepend conditioning, optionally concatenates a conditioning
    signal channel-wise, and runs a ContinuousTransformer over the (optionally
    patched) latent sequence. Global conditioning is either prepended as an
    extra token ("prepend") or fed to each block via adaLN ("adaLN").
    """
    def __init__(self,
        io_channels=64,
        patch_size=1,
        embed_dim=1536,
        cond_token_dim=768,
        project_cond_tokens=False,
        global_cond_dim=1536,
        project_global_cond=True,
        input_concat_dim=0,
        prepend_cond_dim=0,
        depth=24,
        num_heads=24,
        transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer",
        global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
        audio_model="",
        dtype=None,
        device=None,
        operations=None,
        **kwargs):

        super().__init__()

        self.dtype = dtype
        self.cond_token_dim = cond_token_dim

        # Timestep embeddings
        timestep_features_dim = 256

        self.timestep_features = FourierFeatures(1, timestep_features_dim, dtype=dtype, device=device)

        self.to_timestep_embed = nn.Sequential(
            operations.Linear(timestep_features_dim, embed_dim, bias=True, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device),
        )

        if cond_token_dim > 0:
            # Conditioning tokens

            cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
            self.to_cond_embed = nn.Sequential(
                operations.Linear(cond_token_dim, cond_embed_dim, bias=False, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Linear(cond_embed_dim, cond_embed_dim, bias=False, dtype=dtype, device=device)
            )
        else:
            cond_embed_dim = 0

        if global_cond_dim > 0:
            # Global conditioning
            global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
            self.to_global_embed = nn.Sequential(
                operations.Linear(global_cond_dim, global_embed_dim, bias=False, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Linear(global_embed_dim, global_embed_dim, bias=False, dtype=dtype, device=device)
            )

        if prepend_cond_dim > 0:
            # Prepend conditioning
            self.to_prepend_embed = nn.Sequential(
                operations.Linear(prepend_cond_dim, embed_dim, bias=False, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
            )

        self.input_concat_dim = input_concat_dim

        # Channels fed to the transformer: audio latents + concat conditioning.
        dim_in = io_channels + self.input_concat_dim

        self.patch_size = patch_size

        # Transformer

        self.transformer_type = transformer_type

        self.global_cond_type = global_cond_type

        if self.transformer_type == "continuous_transformer":

            global_dim = None

            if self.global_cond_type == "adaLN":
                # The global conditioning is projected to the embed_dim already at this point
                global_dim = embed_dim

            self.transformer = ContinuousTransformer(
                dim=embed_dim,
                depth=depth,
                dim_heads=embed_dim // num_heads,
                dim_in=dim_in * patch_size,
                dim_out=io_channels * patch_size,
                cross_attend = cond_token_dim > 0,
                cond_token_dim = cond_embed_dim,
                global_cond_dim=global_dim,
                dtype=dtype,
                device=device,
                operations=operations,
                **kwargs
            )
        else:
            raise ValueError(f"Unknown transformer type: {self.transformer_type}")

        # 1x1 convs applied residually before/after the transformer.
        self.preprocess_conv = operations.Conv1d(dim_in, dim_in, 1, bias=False, dtype=dtype, device=device)
        self.postprocess_conv = operations.Conv1d(io_channels, io_channels, 1, bias=False, dtype=dtype, device=device)

    def _forward(
        self,
        x,
        t,
        mask=None,
        cross_attn_cond=None,
        cross_attn_cond_mask=None,
        input_concat_cond=None,
        global_embed=None,
        prepend_cond=None,
        prepend_cond_mask=None,
        return_info=False,
        **kwargs):

        if cross_attn_cond is not None:
            cross_attn_cond = self.to_cond_embed(cross_attn_cond)

        if global_embed is not None:
            # Project the global conditioning to the embedding dimension
            global_embed = self.to_global_embed(global_embed)

        prepend_inputs = None
        prepend_mask = None
        prepend_length = 0
        if prepend_cond is not None:
            # Project the prepend conditioning to the embedding dimension
            prepend_cond = self.to_prepend_embed(prepend_cond)

            prepend_inputs = prepend_cond
            if prepend_cond_mask is not None:
                prepend_mask = prepend_cond_mask

        if input_concat_cond is not None:

            # Interpolate input_concat_cond to the same length as x
            if input_concat_cond.shape[2] != x.shape[2]:
                input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')

            x = torch.cat([x, input_concat_cond], dim=1)

        # Get the batch of timestep embeddings
        timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None]).to(x.dtype)) # (b, embed_dim)

        # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
        if global_embed is not None:
            global_embed = global_embed + timestep_embed
        else:
            global_embed = timestep_embed

        # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
        if self.global_cond_type == "prepend":
            if prepend_inputs is None:
                # Prepend inputs are just the global embed, and the mask is all ones
                prepend_inputs = global_embed.unsqueeze(1)
                prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
            else:
                # Prepend inputs are the prepend conditioning + the global embed
                prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
                prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1)

            prepend_length = prepend_inputs.shape[1]

        # Residual 1x1 conv on the channel dim before entering the transformer.
        x = self.preprocess_conv(x) + x

        x = rearrange(x, "b c t -> b t c")

        extra_args = {}

        if self.global_cond_type == "adaLN":
            extra_args["global_cond"] = global_embed

        if self.patch_size > 1:
            # Fold `patch_size` timesteps into the channel dim.
            x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)

        # NOTE(review): only "continuous_transformer" can be constructed in
        # __init__; the "x-transformers" / "mm_transformer" branches here are
        # unreachable dead code kept from upstream.
        if self.transformer_type == "x-transformers":
            output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, **extra_args, **kwargs)
        elif self.transformer_type == "continuous_transformer":
            output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs)

            if return_info:
                output, info = output
        elif self.transformer_type == "mm_transformer":
            output = self.transformer(x, context=cross_attn_cond, mask=mask, context_mask=cross_attn_cond_mask, **extra_args, **kwargs)

        # Drop the prepended conditioning tokens and go back to (b, c, t).
        output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:]

        if self.patch_size > 1:
            output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)

        output = self.postprocess_conv(output) + output

        if return_info:
            return output, info

        return output

    def forward(
        self,
        x,
        timestep,
        context=None,
        context_mask=None,
        input_concat_cond=None,
        global_embed=None,
        negative_global_embed=None,
        prepend_cond=None,
        prepend_cond_mask=None,
        mask=None,
        return_info=False,
        control=None,
        transformer_options={},
        **kwargs):
        # Thin adapter to the sampler's calling convention; `negative_global_embed`,
        # `control` and `transformer_options` are accepted but unused here.
        return self._forward(
            x,
            timestep,
            cross_attn_cond=context,
            cross_attn_cond_mask=context_mask,
            input_concat_cond=input_concat_cond,
            global_embed=global_embed,
            prepend_cond=prepend_cond,
            prepend_cond_mask=prepend_cond_mask,
            mask=mask,
            return_info=return_info,
            **kwargs
        )
content/flux/totoro/ldm/audio/embedders.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # code adapted from: https://github.com/Stability-AI/stable-audio-tools
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch import Tensor, einsum
6
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
7
+ from einops import rearrange
8
+ import math
9
+ import totoro.ops
10
+
11
class LearnedPositionalEmbedding(nn.Module):
    """Used for continuous time.

    Maps a (batch,) tensor of scalars to (batch, dim + 1) features:
    [x, sin(2*pi*x*w), cos(2*pi*x*w)] with learned frequencies w.
    """

    def __init__(self, dim: int):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        # torch.empty leaves the frequencies uninitialized — real values are
        # presumably loaded from a checkpoint state dict; confirm before training.
        self.weights = nn.Parameter(torch.empty(half_dim))

    def forward(self, x: Tensor) -> Tensor:
        # (batch,) -> (batch, 1)
        x = rearrange(x, "b -> b 1")
        freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * math.pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
        # Prepend the raw input so downstream layers also see the unembedded value.
        fouriered = torch.cat((x, fouriered), dim=-1)
        return fouriered
26
+
27
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
    """Factory: Fourier features (dim + 1 values including the raw input)
    projected to `out_features` with a dtype-casting Linear."""
    return nn.Sequential(
        LearnedPositionalEmbedding(dim),
        totoro.ops.manual_cast.Linear(in_features=dim + 1, out_features=out_features),
    )
32
+
33
+
34
class NumberEmbedder(nn.Module):
    """Embeds scalars (any leading shape) into `features`-dim vectors via
    learned Fourier features followed by a linear projection."""

    def __init__(
        self,
        features: int,
        dim: int = 256,
    ):
        super().__init__()
        self.features = features
        self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)

    def forward(self, x: Union[List[float], Tensor]) -> Tensor:
        if not torch.is_tensor(x):
            # Build the input tensor on the same device as the embedding weights.
            device = next(self.embedding.parameters()).device
            x = torch.tensor(x, device=device)
        assert isinstance(x, Tensor)
        shape = x.shape
        # Flatten to 1-D, embed each scalar, then restore the leading shape.
        x = rearrange(x, "... -> (...)")
        embedding = self.embedding(x)
        x = embedding.view(*shape, self.features)
        return x # type: ignore
54
+
55
+
56
class Conditioner(nn.Module):
    """Abstract base for conditioning encoders.

    Holds the input/output widths and an optional output projection:
    a Linear when the widths differ (or when explicitly requested),
    otherwise an Identity. Subclasses implement `forward`.
    """

    def __init__(
        self,
        dim: int,
        output_dim: int,
        project_out: bool = False
    ):

        super().__init__()

        self.dim = dim
        self.output_dim = output_dim
        # Project only when shapes force it or the caller asks for it.
        needs_projection = project_out or (dim != output_dim)
        self.proj_out = nn.Linear(dim, output_dim) if needs_projection else nn.Identity()

    def forward(self, x):
        raise NotImplementedError()
72
+
73
class NumberConditioner(Conditioner):
    '''
    Conditioner that takes a list of floats, normalizes them for a given range, and returns a list of embeddings
    '''
    def __init__(self,
                output_dim: int,
                min_val: float=0,
                max_val: float=1
                ):
        super().__init__(output_dim, output_dim)

        # Clamping/normalization range for incoming raw values.
        self.min_val = min_val
        self.max_val = max_val

        self.embedder = NumberEmbedder(features=output_dim)

    def forward(self, floats, device=None):
        """Returns [embeddings (B, 1, output_dim), all-ones mask (B, 1)]."""
        # Cast the inputs to floats
        floats = [float(x) for x in floats]

        if device is None:
            device = next(self.embedder.parameters()).device

        floats = torch.tensor(floats).to(device)

        # Clamp into range, then normalize to [0, 1] before embedding.
        floats = floats.clamp(self.min_val, self.max_val)

        normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val)

        # Cast floats to same type as embedder
        embedder_dtype = next(self.embedder.parameters()).dtype
        normalized_floats = normalized_floats.to(embedder_dtype)

        float_embeds = self.embedder(normalized_floats).unsqueeze(1)

        return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)]
content/flux/totoro/ldm/aura/mmdit.py ADDED
@@ -0,0 +1,480 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #AuraFlow MMDiT
2
+ #Originally written by the AuraFlow Authors
3
+
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from totoro.ldm.modules.attention import optimized_attention
11
+ import totoro.ops
12
+
13
def modulate(x, shift, scale):
    """adaLN-style affine modulation: x * (1 + scale) + shift.

    `shift`/`scale` are per-sample (B, D); a sequence axis is inserted so
    they broadcast over (B, T, D) activations.
    """
    gain = 1 + scale.unsqueeze(1)
    offset = shift.unsqueeze(1)
    return x * gain + offset
15
+
16
+
17
def find_multiple(n: int, k: int) -> int:
    """Round *n* up to the nearest multiple of *k* (returns *n* if already one)."""
    # Ceiling division via negation, then scale back up.
    return -(-n // k) * k
21
+
22
+
23
class MLP(nn.Module):
    """SwiGLU feed-forward block: silu(W1 x) * (W2 x) projected back to `dim`.

    The hidden width is shrunk to 2/3 of `hidden_dim` (keeping parameter
    count comparable to a plain 4x MLP) and padded up to a multiple of 256.
    """

    def __init__(self, dim, hidden_dim=None, dtype=None, device=None, operations=None) -> None:
        super().__init__()
        hidden_dim = 4 * dim if hidden_dim is None else hidden_dim

        inner = find_multiple(int(2 * hidden_dim / 3), 256)

        self.c_fc1 = operations.Linear(dim, inner, bias=False, dtype=dtype, device=device)
        self.c_fc2 = operations.Linear(dim, inner, bias=False, dtype=dtype, device=device)
        self.c_proj = operations.Linear(inner, dim, bias=False, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = F.silu(self.c_fc1(x)) * self.c_fc2(x)
        return self.c_proj(gated)
40
+
41
+
42
class MultiHeadLayerNorm(nn.Module):
    """LayerNorm over the trailing dim with a learned weight and no bias.

    Computation is done in fp32 and cast back to the input dtype.
    `hidden_size` may be a tuple (e.g. (n_heads, head_dim)) for per-head weights.
    Adapted from the Cohere modeling code in huggingface/transformers.
    """

    def __init__(self, hidden_size=None, eps=1e-5, dtype=None, device=None):
        super().__init__()
        # Uninitialized on purpose — weights come from a loaded state dict.
        self.weight = nn.Parameter(torch.empty(hidden_size, dtype=dtype, device=device))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        h = hidden_states.to(torch.float32)
        centered = h - h.mean(-1, keepdim=True)
        variance = centered.pow(2).mean(-1, keepdim=True)
        normed = centered * torch.rsqrt(variance + self.variance_epsilon)
        scaled = self.weight.to(torch.float32) * normed
        return scaled.to(orig_dtype)
60
+
61
class SingleAttention(nn.Module):
    """Multi-head self-attention with per-head-dim QK LayerNorm
    (or full multi-head qk-norm when `mh_qknorm` is set)."""

    def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, operations=None):
        super().__init__()

        self.n_heads = n_heads
        self.head_dim = dim // n_heads

        # this is for cond
        self.w1q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)

        self.q_norm1 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            if mh_qknorm
            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
        )
        self.k_norm1 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            if mh_qknorm
            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
        )

    #@torch.compile()
    def forward(self, c):

        bsz, seqlen1, _ = c.shape

        # Project, split heads to (B, T, H, Dh), then qk-normalize.
        q, k, v = self.w1q(c), self.w1k(c), self.w1v(c)
        q = q.view(bsz, seqlen1, self.n_heads, self.head_dim)
        k = k.view(bsz, seqlen1, self.n_heads, self.head_dim)
        v = v.view(bsz, seqlen1, self.n_heads, self.head_dim)
        q, k = self.q_norm1(q), self.k_norm1(k)

        # optimized_attention with skip_reshape=True takes (B, H, T, Dh).
        output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
        c = self.w1o(output)
        return c
99
+
100
+
101
+
102
class DoubleAttention(nn.Module):
    """Joint attention over the conditioning stream (c) and image stream (x).

    Each stream has its own QKV/output projections and qk-norms; attention
    runs once over the concatenated sequence, and the result is split back
    into the two streams.
    """

    def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, operations=None):
        super().__init__()

        self.n_heads = n_heads
        self.head_dim = dim // n_heads

        # this is for cond
        self.w1q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)

        # this is for x
        self.w2q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w2k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w2v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w2o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)

        self.q_norm1 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            if mh_qknorm
            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
        )
        self.k_norm1 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            if mh_qknorm
            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
        )

        self.q_norm2 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            if mh_qknorm
            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
        )
        self.k_norm2 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            if mh_qknorm
            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
        )


    #@torch.compile()
    def forward(self, c, x):
        """c: (B, T_c, D) conditioning tokens; x: (B, T_x, D) image tokens."""

        bsz, seqlen1, _ = c.shape
        bsz, seqlen2, _ = x.shape
        seqlen = seqlen1 + seqlen2

        cq, ck, cv = self.w1q(c), self.w1k(c), self.w1v(c)
        cq = cq.view(bsz, seqlen1, self.n_heads, self.head_dim)
        ck = ck.view(bsz, seqlen1, self.n_heads, self.head_dim)
        cv = cv.view(bsz, seqlen1, self.n_heads, self.head_dim)
        cq, ck = self.q_norm1(cq), self.k_norm1(ck)

        xq, xk, xv = self.w2q(x), self.w2k(x), self.w2v(x)
        xq = xq.view(bsz, seqlen2, self.n_heads, self.head_dim)
        xk = xk.view(bsz, seqlen2, self.n_heads, self.head_dim)
        xv = xv.view(bsz, seqlen2, self.n_heads, self.head_dim)
        xq, xk = self.q_norm2(xq), self.k_norm2(xk)

        # concat all — conditioning tokens first, then image tokens
        q, k, v = (
            torch.cat([cq, xq], dim=1),
            torch.cat([ck, xk], dim=1),
            torch.cat([cv, xv], dim=1),
        )

        # optimized_attention with skip_reshape=True takes (B, H, T, Dh).
        output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)

        # Split back into the two streams and project each separately.
        c, x = output.split([seqlen1, seqlen2], dim=1)
        c = self.w1o(c)
        x = self.w2o(x)

        return c, x
177
+
178
+
179
class MMDiTBlock(nn.Module):
    """Dual-stream MMDiT block: adaLN-modulated joint attention + per-stream MLPs.

    NOTE(review): when `is_last` is True, `modC` emits only 2*dim and `mlpC`
    is not created, yet `forward` unconditionally chunks modC into 6 and uses
    `self.mlpC` — so an is_last block would fail through this forward path.
    Presumably is_last is never True for double layers in practice; confirm.
    """

    def __init__(self, dim, heads=8, global_conddim=1024, is_last=False, dtype=None, device=None, operations=None):
        super().__init__()

        self.normC1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        self.normC2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        if not is_last:
            self.mlpC = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
            self.modC = nn.Sequential(
                nn.SiLU(),
                operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
            )
        else:
            self.modC = nn.Sequential(
                nn.SiLU(),
                operations.Linear(global_conddim, 2 * dim, bias=False, dtype=dtype, device=device),
            )

        self.normX1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        self.normX2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        self.mlpX = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
        self.modX = nn.Sequential(
            nn.SiLU(),
            operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
        )

        self.attn = DoubleAttention(dim, heads, dtype=dtype, device=device, operations=operations)
        self.is_last = is_last

    #@torch.compile()
    def forward(self, c, x, global_cond, **kwargs):

        cres, xres = c, x

        # Conditioning-stream modulation parameters (shift/scale/gate for MSA and MLP).
        cshift_msa, cscale_msa, cgate_msa, cshift_mlp, cscale_mlp, cgate_mlp = (
            self.modC(global_cond).chunk(6, dim=1)
        )

        c = modulate(self.normC1(c), cshift_msa, cscale_msa)

        # xpath
        xshift_msa, xscale_msa, xgate_msa, xshift_mlp, xscale_mlp, xgate_mlp = (
            self.modX(global_cond).chunk(6, dim=1)
        )

        x = modulate(self.normX1(x), xshift_msa, xscale_msa)

        # attention
        c, x = self.attn(c, x)


        # Gated residual + modulated MLP, per stream.
        c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
        c = cgate_mlp.unsqueeze(1) * self.mlpC(modulate(c, cshift_mlp, cscale_mlp))
        c = cres + c

        x = self.normX2(xres + xgate_msa.unsqueeze(1) * x)
        x = xgate_mlp.unsqueeze(1) * self.mlpX(modulate(x, xshift_mlp, xscale_mlp))
        x = xres + x

        return c, x
239
+
240
class DiTBlock(nn.Module):
    # like MMDiTBlock, but it only has X
    """Single-stream DiT block: adaLN-modulated self-attention + MLP,
    both with gated residual connections."""

    def __init__(self, dim, heads=8, global_conddim=1024, dtype=None, device=None, operations=None):
        super().__init__()

        self.norm1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        self.norm2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)

        self.modCX = nn.Sequential(
            nn.SiLU(),
            operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
        )

        self.attn = SingleAttention(dim, heads, dtype=dtype, device=device, operations=operations)
        self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)

    #@torch.compile()
    def forward(self, cx, global_cond, **kwargs):
        cxres = cx
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX(
            global_cond
        ).chunk(6, dim=1)
        cx = modulate(self.norm1(cx), shift_msa, scale_msa)
        cx = self.attn(cx)
        cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx)
        mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
        cx = gate_mlp.unsqueeze(1) * mlpout

        cx = cxres + cx

        return cx
271
+
272
+
273
+
274
class TimestepEmbedder(nn.Module):
    """Maps scalar timesteps to conditioning vectors: sinusoidal features
    followed by a two-layer SiLU MLP."""

    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
        super().__init__()
        self.mlp = nn.Sequential(
            operations.Linear(frequency_embedding_size, hidden_size, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Sinusoidal embedding of t, shape (B,) -> (B, dim).

        The 1000x frequency factor matches AuraFlow's timestep convention
        — presumably t lives in [0, 1]; confirm against the sampler.
        """
        half = dim // 2
        exponents = -math.log(max_period) * torch.arange(start=0, end=half) / half
        freqs = 1000 * exponents.exp().to(t.device)
        angles = t[:, None] * freqs[None]
        embedding = torch.cat([angles.cos(), angles.sin()], dim=-1)
        if dim % 2:
            # Odd dim: pad with one zero column.
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    #@torch.compile()
    def forward(self, t, dtype):
        feats = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
        return self.mlp(feats)
303
+
304
+
305
class MMDiT(nn.Module):
    """AuraFlow MMDiT: n_double_layers dual-stream blocks followed by
    single-stream blocks over the concatenated [cond, image] sequence.

    NOTE(review): patchify/apply_pos_embeds use `(H + 1) // patch_size`,
    which equals the circular-padded size only for patch_size == 2 — other
    patch sizes would disagree with the computed padding; confirm intended.
    """

    def __init__(
        self,
        in_channels=4,
        out_channels=4,
        patch_size=2,
        dim=3072,
        n_layers=36,
        n_double_layers=4,
        n_heads=12,
        global_conddim=3072,
        cond_seq_dim=2048,
        max_seq=32 * 32,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        self.dtype = dtype

        self.t_embedder = TimestepEmbedder(global_conddim, dtype=dtype, device=device, operations=operations)

        self.cond_seq_linear = operations.Linear(
            cond_seq_dim, dim, bias=False, dtype=dtype, device=device
        ) # linear for something like text sequence.
        self.init_x_linear = operations.Linear(
            patch_size * patch_size * in_channels, dim, dtype=dtype, device=device
        ) # init linear for patchified image.

        # Learned 2D positional table (flattened h_max*w_max grid) and
        # register tokens prepended to the conditioning sequence.
        self.positional_encoding = nn.Parameter(torch.empty(1, max_seq, dim, dtype=dtype, device=device))
        self.register_tokens = nn.Parameter(torch.empty(1, 8, dim, dtype=dtype, device=device))

        self.double_layers = nn.ModuleList([])
        self.single_layers = nn.ModuleList([])


        for idx in range(n_double_layers):
            self.double_layers.append(
                MMDiTBlock(dim, n_heads, global_conddim, is_last=(idx == n_layers - 1), dtype=dtype, device=device, operations=operations)
            )

        for idx in range(n_double_layers, n_layers):
            self.single_layers.append(
                DiTBlock(dim, n_heads, global_conddim, dtype=dtype, device=device, operations=operations)
            )


        self.final_linear = operations.Linear(
            dim, patch_size * patch_size * out_channels, bias=False, dtype=dtype, device=device
        )

        self.modF = nn.Sequential(
            nn.SiLU(),
            operations.Linear(global_conddim, 2 * dim, bias=False, dtype=dtype, device=device),
        )

        self.out_channels = out_channels
        self.patch_size = patch_size
        self.n_double_layers = n_double_layers
        self.n_layers = n_layers

        # Side length of the (square) positional-encoding grid.
        self.h_max = round(max_seq**0.5)
        self.w_max = round(max_seq**0.5)

    @torch.no_grad()
    def extend_pe(self, init_dim=(16, 16), target_dim=(64, 64)):
        # extend pe
        pe_data = self.positional_encoding.data.squeeze(0)[: init_dim[0] * init_dim[1]]

        pe_as_2d = pe_data.view(init_dim[0], init_dim[1], -1).permute(2, 0, 1)

        # now we need to extend this to target_dim. for this we will use interpolation.
        # we will use torch.nn.functional.interpolate
        pe_as_2d = F.interpolate(
            pe_as_2d.unsqueeze(0), size=target_dim, mode="bilinear"
        )
        pe_new = pe_as_2d.squeeze(0).permute(1, 2, 0).flatten(0, 1)
        self.positional_encoding.data = pe_new.unsqueeze(0).contiguous()
        self.h_max, self.w_max = target_dim
        print("PE extended to", target_dim)

    def pe_selection_index_based_on_dim(self, h, w):
        # Flattened indices of a centered h_p x w_p window in the PE grid.
        # (Currently unused by forward — see the commented-out calls there.)
        h_p, w_p = h // self.patch_size, w // self.patch_size
        original_pe_indexes = torch.arange(self.positional_encoding.shape[1])
        original_pe_indexes = original_pe_indexes.view(self.h_max, self.w_max)
        starth = self.h_max // 2 - h_p // 2
        endh =starth + h_p
        startw = self.w_max // 2 - w_p // 2
        endw = startw + w_p
        original_pe_indexes = original_pe_indexes[
            starth:endh, startw:endw
        ]
        return original_pe_indexes.flatten()

    def unpatchify(self, x, h, w):
        # (B, h*w, p*p*c) -> (B, c, h*p, w*p)
        c = self.out_channels
        p = self.patch_size

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def patchify(self, x):
        # Circular-pad H/W up to a multiple of patch_size, then flatten
        # each p x p patch: (B, C, H, W) -> (B, T, p*p*C).
        B, C, H, W = x.size()
        pad_h = (self.patch_size - H % self.patch_size) % self.patch_size
        pad_w = (self.patch_size - W % self.patch_size) % self.patch_size

        x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='circular')
        x = x.view(
            B,
            C,
            (H + 1) // self.patch_size,
            self.patch_size,
            (W + 1) // self.patch_size,
            self.patch_size,
        )
        x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
        return x

    def apply_pos_embeds(self, x, h, w):
        # Add the centered (and, if needed, upscaled) crop of the PE grid.
        h = (h + 1) // self.patch_size
        w = (w + 1) // self.patch_size
        max_dim = max(h, w)

        cur_dim = self.h_max
        pos_encoding = totoro.ops.cast_to_input(self.positional_encoding.reshape(1, cur_dim, cur_dim, -1), x)

        if max_dim > cur_dim:
            # Requested grid is larger than the table: bilinearly upscale it.
            pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, -1)
            cur_dim = max_dim

        from_h = (cur_dim - h) // 2
        from_w = (cur_dim - w) // 2
        pos_encoding = pos_encoding[:,from_h:from_h+h,from_w:from_w+w]
        return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])

    def forward(self, x, timestep, context, **kwargs):
        # patchify x, add PE
        b, c, h, w = x.shape

        # pe_indexes = self.pe_selection_index_based_on_dim(h, w)
        # print(pe_indexes, pe_indexes.shape)

        x = self.init_x_linear(self.patchify(x)) # B, T_x, D
        x = self.apply_pos_embeds(x, h, w)
        # x = x + self.positional_encoding[:, : x.size(1)].to(device=x.device, dtype=x.dtype)
        # x = x + self.positional_encoding[:, pe_indexes].to(device=x.device, dtype=x.dtype)

        # process conditions for MMDiT Blocks
        c_seq = context # B, T_c, D_c
        t = timestep

        c = self.cond_seq_linear(c_seq) # B, T_c, D
        c = torch.cat([totoro.ops.cast_to_input(self.register_tokens, c).repeat(c.size(0), 1, 1), c], dim=1)

        global_cond = self.t_embedder(t, x.dtype) # B, D

        if len(self.double_layers) > 0:
            for layer in self.double_layers:
                c, x = layer(c, x, global_cond, **kwargs)

        if len(self.single_layers) > 0:
            # Single-stream blocks run over the concatenated [cond, image] sequence.
            c_len = c.size(1)
            cx = torch.cat([c, x], dim=1)
            for layer in self.single_layers:
                cx = layer(cx, global_cond, **kwargs)

            x = cx[:, c_len:]

        # Final adaLN modulation, projection, and crop back to the input size.
        fshift, fscale = self.modF(global_cond).chunk(2, dim=1)

        x = modulate(x, fshift, fscale)
        x = self.final_linear(x)
        x = self.unpatchify(x, (h + 1) // self.patch_size, (w + 1) // self.patch_size)[:,:,:h,:w]
        return x
content/flux/totoro/ldm/cascade/common.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file is part of totoroUI.
3
+ Copyright (C) 2024 Stability AI
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+ """
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ from totoro.ldm.modules.attention import optimized_attention
22
+ import totoro.ops
23
+
24
class OptimizedAttention(nn.Module):
    """Multi-head attention with separate q/k/v projections, delegating the
    attention itself to the backend-selected `optimized_attention` kernel.

    NOTE: the `dropout` argument is accepted but unused here.
    """

    def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.heads = nhead

        self.to_q = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
        self.to_k = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
        self.to_v = operations.Linear(c, c, bias=True, dtype=dtype, device=device)

        self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)

    def forward(self, q, k, v):
        # Inputs are (B, T, C); head splitting happens inside optimized_attention.
        q = self.to_q(q)
        k = self.to_k(k)
        v = self.to_v(v)

        out = optimized_attention(q, k, v, self.heads)

        return self.out_proj(out)
43
+
44
class Attention2D(nn.Module):
    """Cross-attention over a 2D feature map: flattens spatial dims to a
    sequence, attends against `kv`, and restores the original NCHW shape."""

    def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
        # self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)

    def forward(self, x, kv, self_attn=False):
        orig_shape = x.shape
        x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
        if self_attn:
            # Let queries also attend to themselves by prepending x to the kv sequence.
            kv = torch.cat([x, kv], dim=1)
        # x = self.attn(x, kv, kv, need_weights=False)[0]
        x = self.attn(x, kv, kv)
        x = x.permute(0, 2, 1).view(*orig_shape)
        return x
59
+
60
+
61
def LayerNorm2d_op(operations):
    """Builds a LayerNorm subclass (from the given `operations` namespace)
    that normalizes NCHW tensors over the channel dimension by permuting
    to channels-last and back."""
    class LayerNorm2d(operations.LayerNorm):
        def forward(self, x):
            channels_last = x.permute(0, 2, 3, 1)
            normed = super().forward(channels_last)
            return normed.permute(0, 3, 1, 2)
    return LayerNorm2d
69
+
70
class GlobalResponseNorm(nn.Module):
    "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
    def __init__(self, dim, dtype=None, device=None):
        super().__init__()
        # gamma/beta are uninitialized here — values are expected from a
        # loaded state dict (ConvNeXt-V2 initializes them to zeros).
        self.gamma = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))
        self.beta = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))

    def forward(self, x):
        # x is channels-last (B, H, W, C): L2 norm over spatial dims,
        # normalized by the per-sample channel mean, then a learned gated residual.
        Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return totoro.ops.cast_to_input(self.gamma, x) * (x * Nx) + totoro.ops.cast_to_input(self.beta, x) + x
81
+
82
+
83
class ResBlock(nn.Module):
    """ConvNeXt-style residual block: depthwise conv + LayerNorm, then a
    channels-last MLP with GlobalResponseNorm, with an outer skip connection.
    Optional `x_skip` features are concatenated before the channelwise MLP."""

    def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0, dtype=None, device=None, operations=None): # , num_heads=4, expansion=2):
        super().__init__()
        self.depthwise = operations.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c, dtype=dtype, device=device)
        # self.depthwise = SAMBlock(c, num_heads, expansion)
        self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.channelwise = nn.Sequential(
            operations.Linear(c + c_skip, c * 4, dtype=dtype, device=device),
            nn.GELU(),
            GlobalResponseNorm(c * 4, dtype=dtype, device=device),
            nn.Dropout(dropout),
            operations.Linear(c * 4, c, dtype=dtype, device=device)
        )

    def forward(self, x, x_skip=None):
        x_res = x
        x = self.norm(self.depthwise(x))
        if x_skip is not None:
            x = torch.cat([x, x_skip], dim=1)
        # Channelwise MLP runs channels-last, hence the permutes.
        x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return x + x_res
104
+
105
+
106
class AttnBlock(nn.Module):
    """Residual cross-attention block: maps the conditioning `kv` into the
    feature width, then attends the (optionally self-attending) feature map."""

    def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.self_attn = self_attn
        self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.attention = Attention2D(c, nhead, dropout, dtype=dtype, device=device, operations=operations)
        self.kv_mapper = nn.Sequential(
            nn.SiLU(),
            operations.Linear(c_cond, c, dtype=dtype, device=device)
        )

    def forward(self, x, kv):
        kv = self.kv_mapper(kv)
        x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
        return x
121
+
122
+
123
class FeedForwardBlock(nn.Module):
    """Residual channelwise MLP block (4x expansion with GELU and
    GlobalResponseNorm), operating channels-last on NCHW input."""

    def __init__(self, c, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.channelwise = nn.Sequential(
            operations.Linear(c, c * 4, dtype=dtype, device=device),
            nn.GELU(),
            GlobalResponseNorm(c * 4, dtype=dtype, device=device),
            nn.Dropout(dropout),
            operations.Linear(c * 4, c, dtype=dtype, device=device)
        )

    def forward(self, x):
        # Permute to channels-last for the Linear layers, then back to NCHW.
        x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return x
138
+
139
+
140
class TimestepBlock(nn.Module):
    """Timestep-conditioned affine modulation: x * (1 + a) + b.

    `t` packs len(conds) + 1 embeddings along dim 1; the first drives the
    base mapper and each extra chunk drives a per-condition mapper whose
    (a, b) outputs are summed into the base ones.

    NOTE(review): `conds=['sca']` is a mutable default — it is only read,
    never mutated, so this is harmless but worth normalizing eventually.
    """

    def __init__(self, c, c_timestep, conds=['sca'], dtype=None, device=None, operations=None):
        super().__init__()
        self.mapper = operations.Linear(c_timestep, c * 2, dtype=dtype, device=device)
        self.conds = conds
        for cname in conds:
            setattr(self, f"mapper_{cname}", operations.Linear(c_timestep, c * 2, dtype=dtype, device=device))

    def forward(self, x, t):
        # Split the packed timestep embedding into one chunk per mapper.
        t = t.chunk(len(self.conds) + 1, dim=1)
        a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
        for i, c in enumerate(self.conds):
            ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
            a, b = a + ac, b + bc
        return x * (1 + a) + b
content/flux/totoro/ldm/cascade/controlnet.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file is part of totoroUI.
3
+ Copyright (C) 2024 Stability AI
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+ """
18
+
19
+ import torch
20
+ import torchvision
21
+ from torch import nn
22
+ from .common import LayerNorm2d_op
23
+
24
+
25
class CNetResBlock(nn.Module):
    """Residual block for the 'large' ControlNet bottleneck:
    (LayerNorm2d -> GELU -> 3x3 Conv) x 2 around a skip connection."""

    def __init__(self, c, dtype=None, device=None, operations=None):
        super().__init__()
        self.blocks = nn.Sequential(
            LayerNorm2d_op(operations)(c, dtype=dtype, device=device),
            nn.GELU(),
            # Fix: forward dtype/device to the convs, consistent with every
            # other operations.* call in this file (previously omitted, so
            # the convs were always created fp32 on the default device).
            operations.Conv2d(c, c, kernel_size=3, padding=1, dtype=dtype, device=device),
            LayerNorm2d_op(operations)(c, dtype=dtype, device=device),
            nn.GELU(),
            operations.Conv2d(c, c, kernel_size=3, padding=1, dtype=dtype, device=device),
        )

    def forward(self, x):
        return x + self.blocks(x)
39
+
40
+
41
class ControlNet(nn.Module):
    """Stable Cascade ControlNet: encodes a control image with one of three
    backbones and emits per-block projection features consumed by the UNet.

    bottleneck_mode:
      - 'effnet' (default): frozen-architecture EfficientNetV2-S features,
        with the stem conv swapped when c_in != 3.
      - 'simple': a small two-conv refiner at the input channel width.
      - 'large': 1x1 bottleneck + 8 CNetResBlocks up to 1280 channels.
    """

    def __init__(self, c_in=3, c_proj=2048, proj_blocks=None, bottleneck_mode=None, dtype=None, device=None, operations=nn):
        super().__init__()
        if bottleneck_mode is None:
            bottleneck_mode = 'effnet'
        # Indices of the UNet blocks that receive a projected control signal.
        self.proj_blocks = proj_blocks
        if bottleneck_mode == 'effnet':
            embd_channels = 1280
            self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
            if c_in != 3:
                # Replace the stem conv; reuse the pretrained RGB weights where possible.
                in_weights = self.backbone[0][0].weight.data
                self.backbone[0][0] = operations.Conv2d(c_in, 24, kernel_size=3, stride=2, bias=False, dtype=dtype, device=device)
                if c_in > 3:
                    # nn.init.constant_(self.backbone[0][0].weight, 0)
                    self.backbone[0][0].weight.data[:, :3] = in_weights[:, :3].clone()
                else:
                    self.backbone[0][0].weight.data = in_weights[:, :c_in].clone()
        elif bottleneck_mode == 'simple':
            embd_channels = c_in
            self.backbone = nn.Sequential(
                operations.Conv2d(embd_channels, embd_channels * 4, kernel_size=3, padding=1, dtype=dtype, device=device),
                nn.LeakyReLU(0.2, inplace=True),
                operations.Conv2d(embd_channels * 4, embd_channels, kernel_size=3, padding=1, dtype=dtype, device=device),
            )
        elif bottleneck_mode == 'large':
            self.backbone = nn.Sequential(
                operations.Conv2d(c_in, 4096 * 4, kernel_size=1, dtype=dtype, device=device),
                nn.LeakyReLU(0.2, inplace=True),
                operations.Conv2d(4096 * 4, 1024, kernel_size=1, dtype=dtype, device=device),
                *[CNetResBlock(1024, dtype=dtype, device=device, operations=operations) for _ in range(8)],
                operations.Conv2d(1024, 1280, kernel_size=1, dtype=dtype, device=device),
            )
            embd_channels = 1280
        else:
            raise ValueError(f'Unknown bottleneck mode: {bottleneck_mode}')
        # One two-conv projection head per targeted UNet block.
        self.projections = nn.ModuleList()
        for _ in range(len(proj_blocks)):
            self.projections.append(nn.Sequential(
                operations.Conv2d(embd_channels, embd_channels, kernel_size=1, bias=False, dtype=dtype, device=device),
                nn.LeakyReLU(0.2, inplace=True),
                operations.Conv2d(embd_channels, c_proj, kernel_size=1, bias=False, dtype=dtype, device=device),
            ))
            # nn.init.constant_(self.projections[-1][-1].weight, 0) # zero output projection
        self.xl = False
        self.input_channels = c_in
        self.unshuffle_amount = 8

    def forward(self, x):
        x = self.backbone(x)
        # Dense list indexed by block id; entries not in proj_blocks stay None.
        proj_outputs = [None for _ in range(max(self.proj_blocks) + 1)]
        for i, idx in enumerate(self.proj_blocks):
            proj_outputs[idx] = self.projections[i](x)
        # Reversed so the consumer can pop() outputs in block order.
        return {"input": proj_outputs[::-1]}
content/flux/totoro/ldm/cascade/stage_a.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file is part of totoroUI.
3
+ Copyright (C) 2024 Stability AI
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+ """
18
+
19
+ import torch
20
+ from torch import nn
21
+ from torch.autograd import Function
22
+
23
class vector_quantize(Function):
    """Straight-through nearest-neighbour vector quantization.

    ``forward`` returns, for every input row, the codebook entry closest in
    L2 distance plus its index. ``backward`` implements the straight-through
    estimator: the output gradient is copied unchanged to the inputs, and
    scattered (via ``index_add_``) into the selected codebook rows.
    """

    @staticmethod
    def forward(ctx, x, codebook):
        with torch.no_grad():
            # Squared L2 distance via the expansion ||c||^2 + ||x||^2 - 2 x.c
            codebook_sqr = torch.sum(codebook ** 2, dim=1)
            x_sqr = torch.sum(x ** 2, dim=1, keepdim=True)

            dist = torch.addmm(codebook_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0)
            _, indices = dist.min(dim=1)

            ctx.save_for_backward(indices, codebook)
            ctx.mark_non_differentiable(indices)

            # FIX: renamed from `nn`, which shadowed the `torch.nn` module
            # imported at file scope.
            nearest = torch.index_select(codebook, 0, indices)
            return nearest, indices

    @staticmethod
    def backward(ctx, grad_output, grad_indices):
        grad_inputs, grad_codebook = None, None

        if ctx.needs_input_grad[0]:
            # Straight-through estimator: pass the gradient to the encoder output.
            grad_inputs = grad_output.clone()
        if ctx.needs_input_grad[1]:
            # Gradient wrt. the codebook: accumulate grads of each selected entry.
            indices, codebook = ctx.saved_tensors

            grad_codebook = torch.zeros_like(codebook)
            grad_codebook.index_add_(0, indices, grad_output)

        return (grad_inputs, grad_codebook)
53
+
54
+
55
class VectorQuantize(nn.Module):
    def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False):
        """
        Takes an input of variable size (as long as the last dimension matches the embedding size).
        Returns one tensor containing the nearest neighbour embeddings to each of the inputs,
        with the same size as the input, vq and commitment components for the loss as a tuple
        in the second output and the indices of the quantized vectors in the third:
        quantized, (vq_loss, commit_loss), indices
        """
        super(VectorQuantize, self).__init__()

        self.codebook = nn.Embedding(k, embedding_size)
        self.codebook.weight.data.uniform_(-1./k, 1./k)
        self.vq = vector_quantize.apply

        self.ema_decay = ema_decay
        self.ema_loss = ema_loss
        if ema_loss:
            # Running statistics for the exponential-moving-average codebook
            # update (used only when training with ema_loss=True).
            self.register_buffer('ema_element_count', torch.ones(k))
            self.register_buffer('ema_weight_sum', torch.zeros_like(self.codebook.weight))

    def _laplace_smoothing(self, x, epsilon):
        # Keeps every entry's count strictly positive so the EMA division in
        # _updateEMA never divides by zero for unused codebook entries.
        n = torch.sum(x)
        return ((x + epsilon) / (n + x.size(0) * epsilon) * n)

    def _updateEMA(self, z_e_x, indices):
        # One-hot assignment mask of shape (num_inputs, k).
        mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()
        elem_count = mask.sum(dim=0)
        weight_sum = torch.mm(mask.t(), z_e_x)

        self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count)
        self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5)
        self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum)

        # Each codebook entry becomes the EMA mean of the vectors assigned to it.
        self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)

    def idx2vq(self, idx, dim=-1):
        """Look up codebook vectors for precomputed indices, optionally moving
        the embedding dimension to `dim`."""
        q_idx = self.codebook(idx)
        if dim != -1:
            q_idx = q_idx.movedim(-1, dim)
        return q_idx

    def forward(self, x, get_losses=True, dim=-1):
        if dim != -1:
            x = x.movedim(dim, -1)
        # Flatten everything but the embedding dimension before quantizing.
        z_e_x = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x
        z_q_x, indices = self.vq(z_e_x, self.codebook.weight.detach())
        vq_loss, commit_loss = None, None
        if self.ema_loss and self.training:
            self._updateEMA(z_e_x.detach(), indices.detach())
        # pick the graded embeddings after updating the codebook in order to have a more accurate commitment loss
        z_q_x_grd = torch.index_select(self.codebook.weight, dim=0, index=indices)
        if get_losses:
            vq_loss = (z_q_x_grd - z_e_x.detach()).pow(2).mean()
            commit_loss = (z_e_x - z_q_x_grd.detach()).pow(2).mean()

        z_q_x = z_q_x.view(x.shape)
        if dim != -1:
            z_q_x = z_q_x.movedim(-1, dim)
        return z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1])
115
+
116
+
117
class ResBlock(nn.Module):
    """Residual block used by StageA: a depthwise 3x3 conv branch followed by
    a channelwise MLP branch, each gated by a learned scale from ``gammas``.

    ``gammas`` is zero-initialized, so a freshly constructed block is exactly
    the identity mapping.
    """

    def __init__(self, c, c_hidden):
        super().__init__()
        # depthwise/attention
        self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
        self.depthwise = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(c, c, kernel_size=3, groups=c)
        )

        # channelwise
        self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
        self.channelwise = nn.Sequential(
            nn.Linear(c, c_hidden),
            nn.GELU(),
            nn.Linear(c_hidden, c),
        )

        # Six modulation scalars: (scale, shift, gate) for each of the two
        # branches; all start at zero so the block begins as an identity.
        self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)

        # Init weights
        def _basic_init(module):
            if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

    def _norm(self, x, norm):
        # LayerNorm over channels of an NCHW tensor: move C last, normalize,
        # move it back.
        return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    def forward(self, x):
        mods = self.gammas

        x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1]
        # FIX: was a bare `except:`, which also swallowed KeyboardInterrupt /
        # SystemExit; the bf16 fallback only needs to catch RuntimeError
        # ("operation not implemented for bf16").
        try:
            x = x + self.depthwise(x_temp) * mods[2]
        except RuntimeError:
            # Run the replication pad in fp32, then the conv in the original dtype.
            x_temp = self.depthwise[0](x_temp.float()).to(x.dtype)
            x = x + self.depthwise[1](x_temp) * mods[2]

        x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4]
        x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]

        return x
163
+
164
+
165
class StageA(nn.Module):
    """Stage A autoencoder of the Stable Cascade pipeline.

    Encodes RGB images into a small latent space (optionally vector-quantized
    through :class:`VectorQuantize`) and decodes latents back to pixels.
    """

    def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192):
        super().__init__()
        self.c_latent = c_latent
        # Channel count per resolution level, shallowest level first.
        c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]

        # Encoder blocks
        self.in_block = nn.Sequential(
            nn.PixelUnshuffle(2),
            nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)
        )
        down_blocks = []
        for i in range(levels):
            if i > 0:
                down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
            block = ResBlock(c_levels[i], c_levels[i] * 4)
            down_blocks.append(block)
        down_blocks.append(nn.Sequential(
            nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
            nn.BatchNorm2d(c_latent),  # then normalize them to have mean 0 and std 1
        ))
        self.down_blocks = nn.Sequential(*down_blocks)
        # FIX: removed dead no-op statement `self.down_blocks[0]`.

        self.codebook_size = codebook_size
        self.vquantizer = VectorQuantize(c_latent, k=codebook_size)

        # Decoder blocks
        up_blocks = [nn.Sequential(
            nn.Conv2d(c_latent, c_levels[-1], kernel_size=1)
        )]
        for i in range(levels):
            # The deepest level gets the full bottleneck stack, others one block.
            for j in range(bottleneck_blocks if i == 0 else 1):
                block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4)
                up_blocks.append(block)
            if i < levels - 1:
                up_blocks.append(
                    nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
                                       padding=1))
        self.up_blocks = nn.Sequential(*up_blocks)
        self.out_block = nn.Sequential(
            nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
            nn.PixelShuffle(2),
        )

    def encode(self, x, quantize=False):
        """Return raw latents, or — when quantize=True — the tuple
        (quantized, raw, indices, combined VQ loss)."""
        x = self.in_block(x)
        x = self.down_blocks(x)
        if quantize:
            qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
            return qe, x, indices, vq_loss + commit_loss * 0.25
        else:
            return x

    def decode(self, x):
        """Decode latents back to image space."""
        x = self.up_blocks(x)
        x = self.out_block(x)
        return x

    def forward(self, x, quantize=False):
        """Autoencode `x`; returns (reconstruction, vq_loss_or_None)."""
        if quantize:
            qe, x, _, vq_loss = self.encode(x, True)
            return self.decode(qe), vq_loss
        # FIX: with quantize=False, encode() returns a bare latent tensor, so
        # the previous unconditional 4-way unpack raised a ValueError. Decode
        # the raw latent and report no VQ loss instead.
        return self.decode(self.encode(x)), None
228
+
229
+
230
class Discriminator(nn.Module):
    """Patch discriminator: a stack of spectrally-normalized stride-2 convs
    producing a per-patch probability map via a 1x1 conv and a sigmoid.
    An optional conditioning vector is broadcast spatially and concatenated
    to the features before the final projection.
    """

    def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6):
        super().__init__()
        d = max(depth - 3, 3)
        # Stem: first downsampling conv + activation.
        modules = [
            nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
            nn.LeakyReLU(0.2),
        ]
        # Remaining stages double the channel count (capped at c_hidden).
        for step in range(depth - 1):
            ch_in = c_hidden // (2 ** max(d - step, 0))
            ch_out = c_hidden // (2 ** max(d - 1 - step, 0))
            modules += [
                nn.utils.spectral_norm(nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=2, padding=1)),
                nn.InstanceNorm2d(ch_out),
                nn.LeakyReLU(0.2),
            ]
        self.encoder = nn.Sequential(*modules)
        shuffle_channels = (c_hidden + c_cond) if c_cond > 0 else c_hidden
        self.shuffle = nn.Conv2d(shuffle_channels, 1, kernel_size=1)
        self.logits = nn.Sigmoid()

    def forward(self, x, cond=None):
        feats = self.encoder(x)
        if cond is not None:
            # Broadcast the conditioning vector over the spatial grid.
            cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, feats.size(-2), feats.size(-1))
            feats = torch.cat([feats, cond], dim=1)
        return self.logits(self.shuffle(feats))
content/flux/totoro/ldm/cascade/stage_b.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file is part of totoroUI.
3
+ Copyright (C) 2024 Stability AI
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+ """
18
+
19
+ import math
20
+ import torch
21
+ from torch import nn
22
+ from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
23
+
24
class StageB(nn.Module):
    """Stage B diffusion UNet of the Stable Cascade pipeline.

    Denoises Stage A latents conditioned on the compressed "effnet" latent
    from Stage C, an optional pixel-space image, pooled CLIP embeddings, and
    sinusoidal timestep embeddings (``r`` plus the extra ``t_conds``).

    NOTE(review): list defaults (c_hidden, blocks, ...) are shared mutable
    defaults; they are only read here, but treat them as read-only.
    """

    def __init__(self, c_in=4, c_out=4, c_r=64, patch_size=2, c_cond=1280, c_hidden=[320, 640, 1280, 1280],
                 nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]],
                 block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]], level_config=['CT', 'CT', 'CTA', 'CTA'], c_clip=1280,
                 c_clip_seq=4, c_effnet=16, c_pixels=3, kernel_size=3, dropout=[0, 0, 0.0, 0.0], self_attn=True,
                 t_conds=['sca'], stable_cascade_stage=None, dtype=None, device=None, operations=None):
        super().__init__()
        self.dtype = dtype
        self.c_r = c_r
        self.t_conds = t_conds
        self.c_clip_seq = c_clip_seq
        # Broadcast scalar dropout / self_attn settings to one entry per level.
        if not isinstance(dropout, list):
            dropout = [dropout] * len(c_hidden)
        if not isinstance(self_attn, list):
            self_attn = [self_attn] * len(c_hidden)

        # CONDITIONING
        # Map effnet latents and optional pixels into the first level's channel
        # space; both maps are added to the embedded input in forward().
        self.effnet_mapper = nn.Sequential(
            operations.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
            nn.GELU(),
            operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        )
        self.pixels_mapper = nn.Sequential(
            operations.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
            nn.GELU(),
            operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        )
        # Each pooled CLIP vector expands to c_clip_seq conditioning tokens.
        self.clip_mapper = operations.Linear(c_clip, c_cond * c_clip_seq, dtype=dtype, device=device)
        self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

        # Patchify the latent input into the first level's channels.
        self.embedding = nn.Sequential(
            nn.PixelUnshuffle(patch_size),
            operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        )

        def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
            # Level configs are strings of these codes: 'C' = conv ResBlock,
            # 'A' = attention, 'F' = feed-forward, 'T' = timestep modulation.
            if block_type == 'C':
                return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'A':
                return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'F':
                return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'T':
                return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
            else:
                raise Exception(f'Block type {block_type} not supported')

        # BLOCKS
        # -- down blocks
        self.down_blocks = nn.ModuleList()
        self.down_downscalers = nn.ModuleList()
        self.down_repeat_mappers = nn.ModuleList()
        for i in range(len(c_hidden)):
            if i > 0:
                self.down_downscalers.append(nn.Sequential(
                    LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
                    operations.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2, dtype=dtype, device=device),
                ))
            else:
                self.down_downscalers.append(nn.Identity())
            down_block = nn.ModuleList()
            for _ in range(blocks[0][i]):
                for block_type in level_config[i]:
                    block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
                    down_block.append(block)
            self.down_blocks.append(down_block)
            if block_repeat is not None:
                # 1x1 convs applied between repeated passes over this level.
                block_repeat_mappers = nn.ModuleList()
                for _ in range(block_repeat[0][i] - 1):
                    block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
                self.down_repeat_mappers.append(block_repeat_mappers)

        # -- up blocks
        self.up_blocks = nn.ModuleList()
        self.up_upscalers = nn.ModuleList()
        self.up_repeat_mappers = nn.ModuleList()
        for i in reversed(range(len(c_hidden))):
            if i > 0:
                self.up_upscalers.append(nn.Sequential(
                    LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
                    operations.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2, dtype=dtype, device=device),
                ))
            else:
                self.up_upscalers.append(nn.Identity())
            up_block = nn.ModuleList()
            for j in range(blocks[1][::-1][i]):
                for k, block_type in enumerate(level_config[i]):
                    # The first block of a non-shallowest level receives a skip
                    # connection from the down path.
                    c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
                    block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
                                      self_attn=self_attn[i])
                    up_block.append(block)
            self.up_blocks.append(up_block)
            if block_repeat is not None:
                block_repeat_mappers = nn.ModuleList()
                for _ in range(block_repeat[1][::-1][i] - 1):
                    block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
                self.up_repeat_mappers.append(block_repeat_mappers)

        # OUTPUT
        # Un-patchify back to c_out channels at the input resolution.
        self.clf = nn.Sequential(
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
            operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
            nn.PixelShuffle(patch_size),
        )

        # --- WEIGHT INIT ---
        # self.apply(self._init_weights) # General init
        # nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings
        # nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings
        # nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings
        # nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings
        # nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings
        # torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
        # nn.init.constant_(self.clf[1].weight, 0) # outputs
        #
        # # blocks
        # for level_block in self.down_blocks + self.up_blocks:
        #     for block in level_block:
        #         if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
        #             block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
        #         elif isinstance(block, TimestepBlock):
        #             for layer in block.modules():
        #                 if isinstance(layer, nn.Linear):
        #                     nn.init.constant_(layer.weight, 0)
        #
        # def _init_weights(self, m):
        #     if isinstance(m, (nn.Conv2d, nn.Linear)):
        #         torch.nn.init.xavier_uniform_(m.weight)
        #         if m.bias is not None:
        #             nn.init.constant_(m.bias, 0)

    def gen_r_embedding(self, r, max_positions=10000):
        """Sinusoidal embedding of the timestep-like scalar `r` -> (B, c_r)."""
        r = r * max_positions
        half_dim = self.c_r // 2
        emb = math.log(max_positions) / (half_dim - 1)
        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
        emb = r[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if self.c_r % 2 == 1:  # zero pad
            emb = nn.functional.pad(emb, (0, 1), mode='constant')
        return emb

    def gen_c_embeddings(self, clip):
        """Expand pooled CLIP vectors into c_clip_seq normalized tokens each."""
        if len(clip.shape) == 2:
            clip = clip.unsqueeze(1)
        clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1)
        clip = self.clip_norm(clip)
        return clip

    def _down_encode(self, x, r_embed, clip):
        """Run the down path; returns per-level features, deepest first.

        The hasattr checks let FSDP-wrapped blocks dispatch like their
        underlying block type.
        """
        level_outputs = []
        block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
        for down_block, downscaler, repmap in block_group:
            x = downscaler(x)
            for i in range(len(repmap) + 1):
                for block in down_block:
                    if isinstance(block, ResBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  ResBlock)):
                        x = block(x)
                    elif isinstance(block, AttnBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  AttnBlock)):
                        x = block(x, clip)
                    elif isinstance(block, TimestepBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  TimestepBlock)):
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if i < len(repmap):
                    x = repmap[i](x)
            level_outputs.insert(0, x)
        return level_outputs

    def _up_decode(self, level_outputs, r_embed, clip):
        """Run the up path over the (deepest-first) down-path features."""
        x = level_outputs[0]
        block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
        for i, (up_block, upscaler, repmap) in enumerate(block_group):
            for j in range(len(repmap) + 1):
                for k, block in enumerate(up_block):
                    if isinstance(block, ResBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  ResBlock)):
                        skip = level_outputs[i] if k == 0 and i > 0 else None
                        if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
                            # Resize to the skip's resolution when shapes drifted
                            # (e.g. odd input sizes).
                            x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear',
                                                                align_corners=True)
                        x = block(x, skip)
                    elif isinstance(block, AttnBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  AttnBlock)):
                        x = block(x, clip)
                    elif isinstance(block, TimestepBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  TimestepBlock)):
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if j < len(repmap):
                    x = repmap[j](x)
            x = upscaler(x)
        return x

    def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
        """Denoise latents `x` at timestep `r` given effnet/CLIP/pixel conditioning."""
        if pixels is None:
            # Default to a blank 8x8 pixel conditioning image.
            pixels = x.new_zeros(x.size(0), 3, 8, 8)

        # Process the conditioning embeddings
        r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
        for c in self.t_conds:
            t_cond = kwargs.get(c, torch.zeros_like(r))
            r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
        clip = self.gen_c_embeddings(clip)

        # Model Blocks
        x = self.embedding(x)
        x = x + self.effnet_mapper(
            nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
        x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
                                          align_corners=True)
        level_outputs = self._down_encode(x, r_embed, clip)
        x = self._up_decode(level_outputs, r_embed, clip)
        return self.clf(x)

    def update_weights_ema(self, src_model, beta=0.999):
        """EMA-blend this model's parameters and buffers toward `src_model`."""
        for self_params, src_params in zip(self.parameters(), src_model.parameters()):
            self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
        for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
            self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
content/flux/totoro/ldm/cascade/stage_c.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file is part of totoroUI.
3
+ Copyright (C) 2024 Stability AI
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+ """
18
+
19
+ import torch
20
+ from torch import nn
21
+ import math
22
+ from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
23
+ # from .controlnet import ControlNetDeliverer
24
+
25
class UpDownBlock2d(nn.Module):
    """1x1 channel projection combined with an optional 2x bilinear resample.

    For mode='up' the resample runs before the projection; for mode='down'
    the projection runs first. With enabled=False the resample is an identity
    and only the channel mapping is applied.
    """

    def __init__(self, c_in, c_out, mode, enabled=True, dtype=None, device=None, operations=None):
        super().__init__()
        assert mode in ['up', 'down']
        if enabled:
            scale = 2 if mode == 'up' else 0.5
            resample = nn.Upsample(scale_factor=scale, mode='bilinear', align_corners=True)
        else:
            resample = nn.Identity()
        project = operations.Conv2d(c_in, c_out, kernel_size=1, dtype=dtype, device=device)
        ordered = [resample, project] if mode == 'up' else [project, resample]
        self.blocks = nn.ModuleList(ordered)

    def forward(self, x):
        out = x
        for stage in self.blocks:
            out = stage(out)
        return out
38
+
39
+
40
+ class StageC(nn.Module):
41
+ def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32],
42
+ blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'],
43
+ c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3,
44
+ dropout=[0.0, 0.0], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False], stable_cascade_stage=None,
45
+ dtype=None, device=None, operations=None):
46
+ super().__init__()
47
+ self.dtype = dtype
48
+ self.c_r = c_r
49
+ self.t_conds = t_conds
50
+ self.c_clip_seq = c_clip_seq
51
+ if not isinstance(dropout, list):
52
+ dropout = [dropout] * len(c_hidden)
53
+ if not isinstance(self_attn, list):
54
+ self_attn = [self_attn] * len(c_hidden)
55
+
56
+ # CONDITIONING
57
+ self.clip_txt_mapper = operations.Linear(c_clip_text, c_cond, dtype=dtype, device=device)
58
+ self.clip_txt_pooled_mapper = operations.Linear(c_clip_text_pooled, c_cond * c_clip_seq, dtype=dtype, device=device)
59
+ self.clip_img_mapper = operations.Linear(c_clip_img, c_cond * c_clip_seq, dtype=dtype, device=device)
60
+ self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
61
+
62
+ self.embedding = nn.Sequential(
63
+ nn.PixelUnshuffle(patch_size),
64
+ operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
65
+ LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6)
66
+ )
67
+
68
+ def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
69
+ if block_type == 'C':
70
+ return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
71
+ elif block_type == 'A':
72
+ return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
73
+ elif block_type == 'F':
74
+ return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
75
+ elif block_type == 'T':
76
+ return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
77
+ else:
78
+ raise Exception(f'Block type {block_type} not supported')
79
+
80
+ # BLOCKS
81
+ # -- down blocks
82
+ self.down_blocks = nn.ModuleList()
83
+ self.down_downscalers = nn.ModuleList()
84
+ self.down_repeat_mappers = nn.ModuleList()
85
+ for i in range(len(c_hidden)):
86
+ if i > 0:
87
+ self.down_downscalers.append(nn.Sequential(
88
+ LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
89
+ UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
90
+ ))
91
+ else:
92
+ self.down_downscalers.append(nn.Identity())
93
+ down_block = nn.ModuleList()
94
+ for _ in range(blocks[0][i]):
95
+ for block_type in level_config[i]:
96
+ block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
97
+ down_block.append(block)
98
+ self.down_blocks.append(down_block)
99
+ if block_repeat is not None:
100
+ block_repeat_mappers = nn.ModuleList()
101
+ for _ in range(block_repeat[0][i] - 1):
102
+ block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
103
+ self.down_repeat_mappers.append(block_repeat_mappers)
104
+
105
+ # -- up blocks
106
+ self.up_blocks = nn.ModuleList()
107
+ self.up_upscalers = nn.ModuleList()
108
+ self.up_repeat_mappers = nn.ModuleList()
109
+ for i in reversed(range(len(c_hidden))):
110
+ if i > 0:
111
+ self.up_upscalers.append(nn.Sequential(
112
+ LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6),
113
+ UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
114
+ ))
115
+ else:
116
+ self.up_upscalers.append(nn.Identity())
117
+ up_block = nn.ModuleList()
118
+ for j in range(blocks[1][::-1][i]):
119
+ for k, block_type in enumerate(level_config[i]):
120
+ c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
121
+ block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
122
+ self_attn=self_attn[i])
123
+ up_block.append(block)
124
+ self.up_blocks.append(up_block)
125
+ if block_repeat is not None:
126
+ block_repeat_mappers = nn.ModuleList()
127
+ for _ in range(block_repeat[1][::-1][i] - 1):
128
+ block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
129
+ self.up_repeat_mappers.append(block_repeat_mappers)
130
+
131
+ # OUTPUT
132
+ self.clf = nn.Sequential(
133
+ LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
134
+ operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
135
+ nn.PixelShuffle(patch_size),
136
+ )
137
+
138
+ # --- WEIGHT INIT ---
139
+ # self.apply(self._init_weights) # General init
140
+ # nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings
141
+ # nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings
142
+ # nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings
143
+ # torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
144
+ # nn.init.constant_(self.clf[1].weight, 0) # outputs
145
+ #
146
+ # # blocks
147
+ # for level_block in self.down_blocks + self.up_blocks:
148
+ # for block in level_block:
149
+ # if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
150
+ # block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
151
+ # elif isinstance(block, TimestepBlock):
152
+ # for layer in block.modules():
153
+ # if isinstance(layer, nn.Linear):
154
+ # nn.init.constant_(layer.weight, 0)
155
+ #
156
+ # def _init_weights(self, m):
157
+ # if isinstance(m, (nn.Conv2d, nn.Linear)):
158
+ # torch.nn.init.xavier_uniform_(m.weight)
159
+ # if m.bias is not None:
160
+ # nn.init.constant_(m.bias, 0)
161
+
162
+ def gen_r_embedding(self, r, max_positions=10000):
163
+ r = r * max_positions
164
+ half_dim = self.c_r // 2
165
+ emb = math.log(max_positions) / (half_dim - 1)
166
+ emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
167
+ emb = r[:, None] * emb[None, :]
168
+ emb = torch.cat([emb.sin(), emb.cos()], dim=1)
169
+ if self.c_r % 2 == 1: # zero pad
170
+ emb = nn.functional.pad(emb, (0, 1), mode='constant')
171
+ return emb
172
+
173
+ def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img):
174
+ clip_txt = self.clip_txt_mapper(clip_txt)
175
+ if len(clip_txt_pooled.shape) == 2:
176
+ clip_txt_pooled = clip_txt_pooled.unsqueeze(1)
177
+ if len(clip_img.shape) == 2:
178
+ clip_img = clip_img.unsqueeze(1)
179
+ clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1)
180
+ clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1)
181
+ clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
182
+ clip = self.clip_norm(clip)
183
+ return clip
184
+
185
    def _down_encode(self, x, r_embed, clip, cnet=None):
        """Encoder half of the UNet.

        Runs each level's block stack (repeated len(repmap)+1 times, with a
        repeat-mapper between passes) and returns the per-level features,
        deepest level first, for `_up_decode` to consume.
        """
        level_outputs = []
        block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
        for down_block, downscaler, repmap in block_group:
            x = downscaler(x)
            for i in range(len(repmap) + 1):
                for block in down_block:
                    # every isinstance check also looks through an FSDP wrapper
                    if isinstance(block, ResBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                 ResBlock)):
                        if cnet is not None:
                            # ControlNet residuals are consumed back-to-front via pop()
                            next_cnet = cnet.pop()
                            if next_cnet is not None:
                                # resize the residual to the current spatial size before adding
                                x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
                                                                  align_corners=True).to(x.dtype)
                        x = block(x)
                    elif isinstance(block, AttnBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  AttnBlock)):
                        # cross-attend to the CLIP conditioning sequence
                        x = block(x, clip)
                    elif isinstance(block, TimestepBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  TimestepBlock)):
                        # timestep-conditioned block
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if i < len(repmap):
                    x = repmap[i](x)
            # prepend so the deepest level ends up at index 0
            level_outputs.insert(0, x)
        return level_outputs
215
+
216
    def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
        """Decoder half of the UNet.

        Walks the levels deepest-first, fusing the encoder skip features from
        `level_outputs` (as produced by `_down_encode`) and upscaling after
        each level.
        """
        x = level_outputs[0]
        block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
        for i, (up_block, upscaler, repmap) in enumerate(block_group):
            for j in range(len(repmap) + 1):
                for k, block in enumerate(up_block):
                    # every isinstance check also looks through an FSDP wrapper
                    if isinstance(block, ResBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                 ResBlock)):
                        # the first ResBlock of every level except the deepest
                        # receives the matching encoder skip connection
                        skip = level_outputs[i] if k == 0 and i > 0 else None
                        if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
                            # resize x to the skip's spatial resolution before fusing
                            x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear',
                                                                align_corners=True)
                        if cnet is not None:
                            # ControlNet residuals are consumed back-to-front via pop()
                            next_cnet = cnet.pop()
                            if next_cnet is not None:
                                x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
                                                                  align_corners=True).to(x.dtype)
                        x = block(x, skip)
                    elif isinstance(block, AttnBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  AttnBlock)):
                        # cross-attend to the CLIP conditioning sequence
                        x = block(x, clip)
                    elif isinstance(block, TimestepBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  TimestepBlock)):
                        # timestep-conditioned block
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if j < len(repmap):
                    x = repmap[j](x)
            x = upscaler(x)
        return x
249
+
250
    def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs):
        """Denoiser forward pass.

        x: latent batch; r: timestep batch; clip_*: text/pooled/image CLIP
        conditionings; control: optional dict with ControlNet residuals under
        the "input" key.
        """
        # Process the conditioning embeddings
        r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
        # extra timestep-like conditions are concatenated onto the timestep
        # embedding; any missing one defaults to zeros
        for c in self.t_conds:
            t_cond = kwargs.get(c, torch.zeros_like(r))
            r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
        clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img)

        # optional ControlNet residual stack
        if control is not None:
            cnet = control.get("input")
        else:
            cnet = None

        # Model Blocks
        x = self.embedding(x)
        level_outputs = self._down_encode(x, r_embed, clip, cnet)
        x = self._up_decode(level_outputs, r_embed, clip, cnet)
        return self.clf(x)
268
+
269
+ def update_weights_ema(self, src_model, beta=0.999):
270
+ for self_params, src_params in zip(self.parameters(), src_model.parameters()):
271
+ self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
272
+ for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
273
+ self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
content/flux/totoro/ldm/cascade/stage_c_coder.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file is part of totoroUI.
3
+ Copyright (C) 2024 Stability AI
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+ """
18
+ import torch
19
+ import torchvision
20
+ from torch import nn
21
+
22
+
23
+ # EfficientNet
24
class EfficientNetEncoder(nn.Module):
    """Stage-C latent encoder: EfficientNetV2-S feature backbone followed by a
    1x1 projection to `c_latent` channels.

    The backbone is created with random weights here (pretrained weights are
    presumably loaded from a checkpoint elsewhere — confirm with caller).
    """
    def __init__(self, c_latent=16):
        super().__init__()
        self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
        self.mapper = nn.Sequential(
            nn.Conv2d(1280, c_latent, kernel_size=1, bias=False),
            nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1
        )
        # ImageNet normalization statistics; stored as nn.Parameter so they are
        # part of the state_dict and follow .to()/device moves
        self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]))
        self.std = nn.Parameter(torch.tensor([0.229, 0.224, 0.225]))

    def forward(self, x):
        # input is expected in [-1, 1]; map to [0, 1] then apply ImageNet stats
        x = x * 0.5 + 0.5
        x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1])
        o = self.mapper(self.backbone(x))
        return o
40
+
41
+
42
+ # Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192
43
# Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192
class Previewer(nn.Module):
    """Lightweight convolutional decoder that upsamples Stage-C latents 8x
    (three stride-2 transposed convolutions) to an RGB preview in [-1, 1]."""
    def __init__(self, c_in=16, c_hidden=512, c_out=3):
        super().__init__()
        self.blocks = nn.Sequential(
            nn.Conv2d(c_in, c_hidden, kernel_size=1),  # 16 channels to 512 channels
            nn.GELU(),
            nn.BatchNorm2d(c_hidden),

            nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(c_hidden),

            nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2),  # 16 -> 32
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 2),

            nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 2),

            nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2),  # 32 -> 64
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 4),

            nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 4),

            nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2),  # 64 -> 128
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 4),

            nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 4),

            nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
        )

    def forward(self, x):
        # network output is treated as [0, 1]; rescale to [-1, 1]
        return (self.blocks(x) - 0.5) * 2.0
84
+
85
class StageC_coder(nn.Module):
    """Bundles the Stage-C latent codec.

    encode(): image -> latent, via EfficientNetEncoder.
    decode(): latent -> fast RGB preview, via Previewer.
    """
    def __init__(self):
        super().__init__()
        self.previewer = Previewer()
        self.encoder = EfficientNetEncoder()

    def encode(self, x):
        return self.encoder(x)

    def decode(self, x):
        return self.previewer(x)
content/flux/totoro/ldm/flux/layers.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ from einops import rearrange
6
+ from torch import Tensor, nn
7
+
8
+ from .math import attention, rope
9
+ import totoro.ops
10
+
11
class EmbedND(nn.Module):
    """Multi-axis rotary positional embedding.

    Applies `rope` to each coordinate axis of `ids` with its own frequency
    count (`axes_dim[i]`) and concatenates the rotation tables.
    """

    def __init__(self, dim: int, theta: int, axes_dim: list):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        per_axis = []
        for axis in range(ids.shape[-1]):
            per_axis.append(rope(ids[..., axis], self.axes_dim[axis], self.theta))
        # concatenate along the frequency dimension, then add a broadcast head dim
        return torch.cat(per_axis, dim=-3).unsqueeze(1)
26
+
27
+
28
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.

    :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param time_factor: multiplier applied to `t` before embedding.
    :return: an (N, dim) Tensor of positional embeddings, cast to t's dtype
             when t is floating point.
    """
    scaled = time_factor * t
    half = dim // 2
    # log-spaced frequencies from 1 down to ~1/max_period
    exponents = torch.arange(start=0, end=half, dtype=torch.float32) * (-math.log(max_period) / half)
    freqs = torch.exp(exponents).to(scaled.device)

    args = scaled[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # odd dim: append a single zero column
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding
50
+
51
+
52
class MLPEmbedder(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) mapping `in_dim` to `hidden_dim`.

    `operations` supplies the Linear implementation (e.g. a cast-aware variant).
    """

    def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
        self.silu = nn.SiLU()
        self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.silu(self.in_layer(x))
        return self.out_layer(hidden)
61
+
62
+
63
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization with a learned per-channel scale (no bias).

    `operations` is accepted for interface symmetry with the other layers but
    is unused here; `dtype`/`device` apply to the scale parameter.
    """
    def __init__(self, dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        # note: created with torch.empty — weights must be loaded from a checkpoint
        self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))

    def forward(self, x: Tensor):
        x_dtype = x.dtype
        # normalize in fp32 for numerical stability, then cast back
        x = x.float()
        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
        return (x * rrms).to(dtype=x_dtype) * totoro.ops.cast_to(self.scale, dtype=x_dtype, device=x.device)
73
+
74
+
75
class QKNorm(torch.nn.Module):
    """Applies independent RMSNorm to queries and keys, casting both to v's
    dtype/device via Tensor.to(v)."""

    def __init__(self, dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
        self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
        normed_q = self.query_norm(q).to(v)
        normed_k = self.key_norm(k).to(v)
        return normed_q, normed_k
85
+
86
+
87
class SelfAttention(nn.Module):
    """Multi-head self-attention with fused QKV projection, QK RMS-norm, and
    rotary position embedding applied inside `attention`."""

    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
        self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)

    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        # pe: rotary embedding table produced by EmbedND
        qkv = self.qkv(x)
        # split the fused projection into q, k, v — each (B, H, L, head_dim)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = self.proj(x)
        return x
104
+
105
+
106
@dataclass
class ModulationOut:
    """One adaLN modulation triple: x -> (1 + scale) * x + shift, with the
    branch output multiplied by `gate` before the residual add."""
    shift: Tensor
    scale: Tensor
    gate: Tensor
111
+
112
+
113
class Modulation(nn.Module):
    """Maps the conditioning vector to one modulation triple — or two when
    `double` is set (attention branch + MLP branch)."""

    def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)

    def forward(self, vec: Tensor) -> tuple:
        # SiLU -> Linear, add a sequence axis, split into shift/scale/gate chunks
        chunks = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
        first = ModulationOut(*chunks[:3])
        second = ModulationOut(*chunks[3:]) if self.is_double else None
        return (first, second)
127
+
128
+
129
class DoubleStreamBlock(nn.Module):
    """MMDiT-style block with separate image and text streams that share one
    joint attention operation. Each stream has its own modulation, norms,
    attention projections, and MLP."""

    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, dtype=None, device=None, operations=None):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
        self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

        self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.img_mlp = nn.Sequential(
            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
            nn.GELU(approximate="tanh"),
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )

        self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
        self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

        self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.txt_mlp = nn.Sequential(
            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
            nn.GELU(approximate="tanh"),
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
        # each stream gets two modulation triples: one for attention, one for MLP
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention (adaLN modulation, then fused QKV + QK norm)
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention jointly over [txt, img] tokens (text first)
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn = attention(q, k, v, pe=pe)
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # gated residual updates for the image stream
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

        # gated residual updates for the text stream
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
        return img, txt
192
+
193
+
194
class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.

    The QKV projection and the MLP input share one fused linear (`linear1`);
    the attention output and the MLP activation share one fused output linear
    (`linear2`).
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float = None,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        # NOTE(review): self.scale is computed but not referenced in this block's
        # forward; scaling presumably happens inside `attention` — confirm.
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
        # proj and mlp_out
        self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)

        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)

        self.hidden_size = hidden_size
        self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        # single modulation triple (double=False, so the second element is None)
        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        return x + mod.gate * output
243
+
244
+
245
class LastLayer(nn.Module):
    """Final adaLN-modulated projection from hidden tokens to
    patch_size**2 * out_channels values per token."""

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        # conditioning vector -> (shift, scale) pair for adaLN
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        modulated = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        return self.linear(modulated)
content/flux/totoro/ldm/flux/math.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from einops import rearrange
3
+ from torch import Tensor
4
+ from totoro.ldm.modules.attention import optimized_attention
5
+ import totoro.model_management
6
+
7
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
8
+ q, k = apply_rope(q, k, pe)
9
+
10
+ heads = q.shape[1]
11
+ x = optimized_attention(q, k, v, heads, skip_reshape=True)
12
+ return x
13
+
14
+
15
+ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
16
+ assert dim % 2 == 0
17
+ if totoro.model_management.is_device_mps(pos.device):
18
+ device = torch.device("cpu")
19
+ else:
20
+ device = pos.device
21
+
22
+ scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
23
+ omega = 1.0 / (theta**scale)
24
+ out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
25
+ out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
26
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
27
+ return out.to(dtype=torch.float32, device=pos.device)
28
+
29
+
30
+ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
31
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
32
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
33
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
34
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
35
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
content/flux/totoro/ldm/flux/model.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #Original code can be found on: https://github.com/black-forest-labs/flux
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import torch
6
+ from torch import Tensor, nn
7
+
8
+ from .layers import (
9
+ DoubleStreamBlock,
10
+ EmbedND,
11
+ LastLayer,
12
+ MLPEmbedder,
13
+ SingleStreamBlock,
14
+ timestep_embedding,
15
+ )
16
+
17
+ from einops import rearrange, repeat
18
+
19
@dataclass
class FluxParams:
    """Hyperparameters for the Flux transformer (consumed by Flux.__init__)."""
    in_channels: int          # channels of each packed patch token
    vec_in_dim: int           # dim of the pooled (y) conditioning vector
    context_in_dim: int       # dim of the text-context tokens
    hidden_size: int          # transformer width (must be divisible by num_heads)
    mlp_ratio: float
    num_heads: int
    depth: int                # number of DoubleStreamBlocks
    depth_single_blocks: int  # number of SingleStreamBlocks
    axes_dim: list            # per-axis rope dims; must sum to hidden_size // num_heads
    theta: int                # rope base frequency
    qkv_bias: bool
    guidance_embed: bool      # True for guidance-distilled checkpoints
34
+
35
class Flux(nn.Module):
    """
    Transformer model for flow matching on sequences.

    Images are packed into 2x2 patch tokens and processed jointly with text
    tokens by double-stream blocks, then as one concatenated sequence by
    single-stream blocks.
    """

    def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
        super().__init__()
        self.dtype = dtype
        params = FluxParams(**kwargs)
        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
        # guidance embedder only exists for guidance-distilled checkpoints
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
        )
        self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                    dtype=dtype, device=device, operations=operations
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
                for _ in range(params.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)

    def forward_orig(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor = None,
    ) -> Tensor:
        """Core pass over already-packed sequences.

        img/txt: (B, L, C) token sequences; img_ids/txt_ids: positional ids
        for rope; y: pooled conditioning vector; guidance: required when the
        model is guidance-distilled.
        """
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))

        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        # joint rope table over text tokens first, then image tokens
        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        # single-stream blocks run on the concatenated [txt, img] sequence;
        # the text prefix is stripped again afterwards
        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img

    def forward(self, x, timestep, context, y, guidance, **kwargs):
        """Latent-space entry point: packs (B, C, H, W) into 2x2 patch tokens,
        builds positional ids, runs forward_orig, and unpacks the result."""
        bs, c, h, w = x.shape
        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

        h_len = (h // 2)
        w_len = (w // 2)
        # ids channel 0 stays zero; channels 1/2 hold the patch row/column index
        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
        img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

        # text tokens all share position 0
        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
        out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance)
        return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)