John6666 committed
Commit aefd7f3 · verified · 1 Parent(s): 6a84530

Upload 16 files

README.md CHANGED
@@ -1,13 +1,13 @@
- ---
- title: T2I test
- emoji: 🖼
- colorFrom: purple
- colorTo: red
- sdk: gradio
- sdk_version: 5.50.0
- app_file: app.py
- pinned: false
- license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: T2I test
+ emoji: 🖼
+ colorFrom: purple
+ colorTo: red
+ sdk: gradio
+ sdk_version: 6.1.0
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,120 +1,118 @@
1
- import spaces
2
- import gradio as gr
3
- from gradio_huggingfacehub_search import HuggingfaceHubSearch
4
- from t2i.infer import (infer, infer_multi, infer_simple, save_image_history, save_gallery_history,
5
- update_param_mode_gr, update_ar_gr,
6
- MAX_SEED, MAX_IMAGE_SIZE, ASPECT_RATIOS, FILE_FORMATS, DEFAULT_TASKS, DEFAULT_DURATION,
7
- DEFAULT_I2I_STRENGTH, DEFAULT_UPSCALE_STRENGTH, DEFAULT_UPSCALE_BY, DEFAULT_CLIP_SKIP,
8
- models, MODEL_TYPES, SAMPLER_NAMES, PRED_TYPES, VAE_NAMES,
9
- UPSCALE_MODES, PARAM_MODES, PIPELINE_TYPES)
10
-
11
- css = """
12
- #col-container {
13
- margin: 0 auto;
14
- max-width: 1080px;
15
- }
16
- """
17
-
18
- with gr.Blocks(fill_height=True, fill_width=True, css=css) as demo:
19
- with gr.Tab("Image Generator"):
20
- lora_dict = gr.State({})
21
- with gr.Column(elem_id="col-container"):
22
- with gr.Tab("Normal"):
23
- with gr.Row():
24
- prompt = gr.Text(label="Prompt", show_label=False, lines=1, placeholder="Enter your prompt", container=False)
25
- run_button = gr.Button("Run", scale=0)
26
- run_button_simple = gr.Button("Simple", scale=0, visible=False) # for API
27
- result = gr.Image(label="Result", show_label=False, format="png", type="filepath", interactive=False, show_share_button=False, show_download_button=True)
28
-
29
- with gr.Tab("Multi"):
30
- with gr.Row():
31
- prompt_multi = gr.Text(label="Prompt", show_label=False, lines=1, placeholder="Enter your prompt", container=False)
32
- run_button_multi = gr.Button("Run", scale=0)
33
- model_name_multi = gr.Dropdown(label="Model", choices=models, value=models[0], multiselect=True, allow_custom_value=True)
34
- num_images = gr.Slider(label="Count", minimum=1, maximum=16, step=1, value=1)
35
- result_multi = gr.Gallery(label="Result", columns=2, object_fit="contain", format="png", interactive=False, show_share_button=False, show_download_button=True)
36
-
37
- with gr.Accordion("Output History", open=False):
38
- history_files = gr.Files(interactive=False, visible=False)
39
- history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", format="png", interactive=False, show_share_button=False,
40
- show_download_button=True)
41
- history_clear_button = gr.Button(value="Clear History", variant="secondary")
42
- history_clear_button.click(lambda: ([], []), None, [history_gallery, history_files], queue=False, show_api=False)
43
-
44
- with gr.Group():
45
- negative_prompt = gr.Text(label="Negative prompt", max_lines=1, placeholder="Enter a negative prompt",
46
- value="") # nsfw, (low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn
47
- with gr.Row(equal_height=True):
48
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
49
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
50
- with gr.Row(equal_height=True):
51
- param_mode = gr.Radio(label="Parameter Settings", choices=PARAM_MODES, value=PARAM_MODES[0])
52
- ar = gr.Dropdown(label="Aspect Ratio", choices=ASPECT_RATIOS, value=ASPECT_RATIOS[0])
53
- with gr.Row(equal_height=True):
54
- width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024, visible=False)
55
- height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024, visible=False)
56
- guidance_scale = gr.Slider(label="Guidance scale", minimum=0.0, maximum=20.0, step=0.1, value=7, visible=False)
57
- num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=60, step=1, value=28, visible=False)
58
- with gr.Group():
59
- model_name = gr.Dropdown(label="Model", choices=models, value=models[0], allow_custom_value=True)
60
- with gr.Accordion("Advanced Settings", open=False):
61
- with gr.Row(equal_height=True):
62
- model_type = gr.Dropdown(label="Model Type", choices=MODEL_TYPES, value=MODEL_TYPES[0])
63
- vae = gr.Dropdown(label="VAE", choices=VAE_NAMES, value=VAE_NAMES[0], allow_custom_value=True)
64
- with gr.Row(equal_height=True):
65
- sampler = gr.Dropdown(label="Sampler", choices=SAMPLER_NAMES, value=SAMPLER_NAMES[0])
66
- pred_type = gr.Dropdown(label="Sampler prediction", choices=PRED_TYPES, value=PRED_TYPES[0])
67
- with gr.Row(equal_height=True):
68
- pipe_type = gr.Dropdown(label="Pipeline Type", choices=PIPELINE_TYPES, value=PIPELINE_TYPES[0])
69
- clip_skip = gr.Slider(label="Clip Skip", minimum=0, maximum=12, step=1, value=DEFAULT_CLIP_SKIP)
70
- with gr.Row(equal_height=True):
71
- task = gr.Radio(label="Task", choices=DEFAULT_TASKS, value=DEFAULT_TASKS[0])
72
- strength = gr.Slider(label="Image-to-Image / Inpainting Strength", minimum=0, maximum=1., step=0.01, value=DEFAULT_I2I_STRENGTH)
73
- input_image = gr.ImageEditor(label="Input Image", type="filepath", sources=["upload", "clipboard", "webcam"], image_mode='RGB',
74
- show_share_button=False, show_fullscreen_button=False, layers=False, canvas_size=(384, 384), width=384, height=512,
75
- brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed", default_size=32), eraser=gr.Eraser(default_size="32"))
76
- with gr.Row(equal_height=True):
77
- upscale_mode = gr.Dropdown(label="Upscaling", choices=UPSCALE_MODES, value=UPSCALE_MODES[0])
78
- upscale_strength = gr.Slider(label="Strength", minimum=0, maximum=1, step=0.05, value=DEFAULT_UPSCALE_STRENGTH)
79
- upscale_by = gr.Slider(label="Upscale by", minimum=1, maximum=1.5, step=0.1, value=DEFAULT_UPSCALE_BY)
80
- with gr.Row(equal_height=True):
81
- format = gr.Dropdown(label="Output Format", choices=FILE_FORMATS, value=FILE_FORMATS[0])
82
- gpu_duration = gr.Number(minimum=0, maximum=240, value=DEFAULT_DURATION, label="GPU time duration (seconds per image)")
83
-
84
- with gr.Tab("PNG Info"):
85
- def extract_exif_data(image):
86
- if image is None: return ""
87
- try:
88
- metadata_keys = ["parameters", "metadata", "prompt", "Comment"]
89
- for key in metadata_keys:
90
- if key in image.info:
91
- return image.info[key]
92
- return str(image.info)
93
- except Exception as e:
94
- return f"Error extracting metadata: {str(e)}"
95
- with gr.Row():
96
- with gr.Column():
97
- image_metadata = gr.Image(label="Image with metadata", type="pil", sources=["upload"])
98
- with gr.Column():
99
- result_metadata = gr.Textbox(label="Metadata", show_label=True, show_copy_button=True, interactive=False, container=True, max_lines=99)
100
- image_metadata.change(fn=extract_exif_data, inputs=[image_metadata], outputs=[result_metadata], show_api=False)
101
-
102
- gr.on(triggers=[run_button.click, prompt.submit], fn=infer,
103
- inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
104
- model_name, sampler, pred_type, vae, model_type, clip_skip, pipe_type, lora_dict, upscale_mode, upscale_strength, upscale_by,
105
- input_image, strength, param_mode, ar, format, task, gpu_duration],
106
- outputs=[result])
107
- gr.on(triggers=[run_button_multi.click, prompt_multi.submit], fn=infer_multi,
108
- inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
109
- model_name_multi, sampler, pred_type, vae, clip_skip, pipe_type, lora_dict, upscale_mode, upscale_strength, upscale_by,
110
- input_image, strength, param_mode, ar, format, num_images, task, gpu_duration],
111
- outputs=[result_multi])
112
- run_button_simple.click(fn=infer_simple, inputs=[prompt, negative_prompt, seed, randomize_seed, model_name], outputs=[result], show_api=True)
113
-
114
- result.change(save_image_history, [result, history_gallery, history_files], [history_gallery, history_files], queue=False, show_api=False)
115
- result_multi.change(save_gallery_history, [result_multi, history_gallery, history_files], [history_gallery, history_files], queue=False, show_api=False)
116
-
117
- ar.change(update_ar_gr, [ar], [width, height], queue=False, show_api=False)
118
- param_mode.change(update_param_mode_gr, [param_mode], [guidance_scale, num_inference_steps], queue=False, show_api=False)
119
-
120
- demo.queue().launch(ssr_mode=False, mcp_server=True)
 
1
+ import spaces
2
+ import gradio as gr
3
+ from gradio_huggingfacehub_search import HuggingfaceHubSearch
4
+ from t2i.infer import (infer, infer_multi, infer_simple, save_image_history, save_gallery_history,
5
+ update_param_mode_gr, update_ar_gr,
6
+ MAX_SEED, MAX_IMAGE_SIZE, ASPECT_RATIOS, FILE_FORMATS, DEFAULT_TASKS, DEFAULT_DURATION,
7
+ DEFAULT_I2I_STRENGTH, DEFAULT_UPSCALE_STRENGTH, DEFAULT_UPSCALE_BY, DEFAULT_CLIP_SKIP,
8
+ models, MODEL_TYPES, SAMPLER_NAMES, PRED_TYPES, VAE_NAMES,
9
+ UPSCALE_MODES, PARAM_MODES, PIPELINE_TYPES)
10
+
11
+ css = """
12
+ #col-container {
13
+ margin: 0 auto;
14
+ max-width: 1080px;
15
+ }
16
+ """
17
+
18
+ with gr.Blocks(fill_height=True, fill_width=True) as demo:
19
+ with gr.Tab("Image Generator"):
20
+ lora_dict = gr.State({})
21
+ with gr.Column(elem_id="col-container"):
22
+ with gr.Tab("Normal"):
23
+ with gr.Row():
24
+ prompt = gr.Text(label="Prompt", show_label=False, lines=1, placeholder="Enter your prompt", container=False)
25
+ run_button = gr.Button("Run", scale=0)
26
+ run_button_simple = gr.Button("Simple", scale=0, visible=False) # for API
27
+ result = gr.Image(label="Result", show_label=False, format="png", type="filepath", interactive=False, buttons=["download", "fullscreen"])
28
+
29
+ with gr.Tab("Multi"):
30
+ with gr.Row():
31
+ prompt_multi = gr.Text(label="Prompt", show_label=False, lines=1, placeholder="Enter your prompt", container=False)
32
+ run_button_multi = gr.Button("Run", scale=0)
33
+ model_name_multi = gr.Dropdown(label="Model", choices=models, value=models[0], multiselect=True, allow_custom_value=True)
34
+ num_images = gr.Slider(label="Count", minimum=1, maximum=16, step=1, value=1)
35
+ result_multi = gr.Gallery(label="Result", columns=2, object_fit="contain", format="png", interactive=False, buttons=["download", "fullscreen"])
36
+
37
+ with gr.Accordion("Output History", open=False):
38
+ history_files = gr.Files(interactive=False, visible=False)
39
+ history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", format="png", interactive=False, buttons=["download", "fullscreen"])
40
+ history_clear_button = gr.Button(value="Clear History", variant="secondary")
41
+ history_clear_button.click(lambda: ([], []), None, [history_gallery, history_files], queue=False, api_visibility="undocumented")
42
+
43
+ with gr.Group():
44
+ negative_prompt = gr.Text(label="Negative prompt", max_lines=1, placeholder="Enter a negative prompt",
45
+ value="") # nsfw, (low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn
46
+ with gr.Row(equal_height=True):
47
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
48
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
49
+ with gr.Row(equal_height=True):
50
+ param_mode = gr.Radio(label="Parameter Settings", choices=PARAM_MODES, value=PARAM_MODES[0])
51
+ ar = gr.Dropdown(label="Aspect Ratio", choices=ASPECT_RATIOS, value=ASPECT_RATIOS[0])
52
+ with gr.Row(equal_height=True):
53
+ width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024, visible=False)
54
+ height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024, visible=False)
55
+ guidance_scale = gr.Slider(label="Guidance scale", minimum=0.0, maximum=20.0, step=0.1, value=7, visible=False)
56
+ num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=60, step=1, value=28, visible=False)
57
+ with gr.Group():
58
+ model_name = gr.Dropdown(label="Model", choices=models, value=models[0], allow_custom_value=True)
59
+ with gr.Accordion("Advanced Settings", open=False):
60
+ with gr.Row(equal_height=True):
61
+ model_type = gr.Dropdown(label="Model Type", choices=MODEL_TYPES, value=MODEL_TYPES[0])
62
+ vae = gr.Dropdown(label="VAE", choices=VAE_NAMES, value=VAE_NAMES[0], allow_custom_value=True)
63
+ with gr.Row(equal_height=True):
64
+ sampler = gr.Dropdown(label="Sampler", choices=SAMPLER_NAMES, value=SAMPLER_NAMES[0])
65
+ pred_type = gr.Dropdown(label="Sampler prediction", choices=PRED_TYPES, value=PRED_TYPES[0])
66
+ with gr.Row(equal_height=True):
67
+ pipe_type = gr.Dropdown(label="Pipeline Type", choices=PIPELINE_TYPES, value=PIPELINE_TYPES[0])
68
+ clip_skip = gr.Slider(label="Clip Skip", minimum=0, maximum=12, step=1, value=DEFAULT_CLIP_SKIP)
69
+ with gr.Row(equal_height=True):
70
+ task = gr.Radio(label="Task", choices=DEFAULT_TASKS, value=DEFAULT_TASKS[0])
71
+ strength = gr.Slider(label="Image-to-Image / Inpainting Strength", minimum=0, maximum=1., step=0.01, value=DEFAULT_I2I_STRENGTH)
72
+ input_image = gr.ImageEditor(label="Input Image", type="filepath", sources=["upload", "clipboard", "webcam"], image_mode='RGB', layers=False, buttons=[], canvas_size=(384, 384), width=384, height=512,
73
+ brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed", default_size=32), eraser=gr.Eraser(default_size="32"))
74
+ with gr.Row(equal_height=True):
75
+ upscale_mode = gr.Dropdown(label="Upscaling", choices=UPSCALE_MODES, value=UPSCALE_MODES[0])
76
+ upscale_strength = gr.Slider(label="Strength", minimum=0, maximum=1, step=0.05, value=DEFAULT_UPSCALE_STRENGTH)
77
+ upscale_by = gr.Slider(label="Upscale by", minimum=1, maximum=1.5, step=0.1, value=DEFAULT_UPSCALE_BY)
78
+ with gr.Row(equal_height=True):
79
+ format = gr.Dropdown(label="Output Format", choices=FILE_FORMATS, value=FILE_FORMATS[0])
80
+ gpu_duration = gr.Number(minimum=0, maximum=240, value=DEFAULT_DURATION, label="GPU time duration (seconds per image)")
81
+
82
+ with gr.Tab("PNG Info"):
83
+ def extract_exif_data(image):
84
+ if image is None: return ""
85
+ try:
86
+ metadata_keys = ["parameters", "metadata", "prompt", "Comment"]
87
+ for key in metadata_keys:
88
+ if key in image.info:
89
+ return image.info[key]
90
+ return str(image.info)
91
+ except Exception as e:
92
+ return f"Error extracting metadata: {str(e)}"
93
+ with gr.Row():
94
+ with gr.Column():
95
+ image_metadata = gr.Image(label="Image with metadata", type="pil", sources=["upload"])
96
+ with gr.Column():
97
+ result_metadata = gr.Textbox(label="Metadata", show_label=True, buttons=["copy"], interactive=False, container=True, max_lines=99)
98
+ image_metadata.change(fn=extract_exif_data, inputs=[image_metadata], outputs=[result_metadata], api_visibility="undocumented")
99
+
100
+ gr.on(triggers=[run_button.click, prompt.submit], fn=infer,
101
+ inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
102
+ model_name, sampler, pred_type, vae, model_type, clip_skip, pipe_type, lora_dict, upscale_mode, upscale_strength, upscale_by,
103
+ input_image, strength, param_mode, ar, format, task, gpu_duration],
104
+ outputs=[result])
105
+ gr.on(triggers=[run_button_multi.click, prompt_multi.submit], fn=infer_multi,
106
+ inputs=[prompt_multi, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
107
+ model_name_multi, sampler, pred_type, vae, clip_skip, pipe_type, lora_dict, upscale_mode, upscale_strength, upscale_by,
108
+ input_image, strength, param_mode, ar, format, num_images, task, gpu_duration],
109
+ outputs=[result_multi])
110
+ run_button_simple.click(fn=infer_simple, inputs=[prompt, negative_prompt, seed, randomize_seed, model_name], outputs=[result])
111
+
112
+ result.change(save_image_history, [result, history_gallery, history_files], [history_gallery, history_files], queue=False, api_visibility="undocumented")
113
+ result_multi.change(save_gallery_history, [result_multi, history_gallery, history_files], [history_gallery, history_files], queue=False, api_visibility="undocumented")
114
+
115
+ ar.change(update_ar_gr, [ar], [width, height], queue=False, api_visibility="undocumented")
116
+ param_mode.change(update_param_mode_gr, [param_mode], [guidance_scale, num_inference_steps], queue=False, api_visibility="undocumented")
117
+
118
+ demo.queue().launch(ssr_mode=False, mcp_server=True, css=css)
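For reference, a minimal sketch of calling this app from Python with gradio_client. The run_button_simple path is explicitly kept for API use (see the "# for API" comment above); the Space id, the "/infer_simple" endpoint name (Gradio's default, derived from the function name), and the model id below are assumptions, not values taken from this repository.

from gradio_client import Client

client = Client("user/t2i-test")  # hypothetical Space id; replace with the real one

# Argument order mirrors the inputs list wired to run_button_simple above.
image_path = client.predict(
    "a cat sitting on a windowsill",  # prompt
    "",                               # negative_prompt
    0,                                # seed
    True,                             # randomize_seed
    "some-model-id",                  # hypothetical; must be one of the choices in `models`
    api_name="/infer_simple",         # assumed default endpoint name
)
print(image_path)  # local path to the generated image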
 
 
requirements.txt CHANGED
@@ -1,17 +1,18 @@
- huggingface_hub
- hf-xet
- torch==2.8.0
- #torchao
- torchvision
- accelerate
- diffusers
- transformers<=4.57.1
- peft
- invisible_watermark
- sentencepiece
- safetensors
- timm
- einops
- kernels
- gradio_huggingfacehub_search
- pydantic==2.10.6
+ huggingface_hub
+ hf-xet
+ torch==2.8.0
+ #torchao
+ torchvision
+ accelerate
+ diffusers
+ transformers<=4.57.1
+ peft
+ invisible_watermark
+ sentencepiece
+ safetensors
+ timm
+ einops
+ kernels
+ gradio_huggingfacehub_search
+ pydantic==2.10.6
+ opencv-python-headless
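The only change here is the new opencv-python-headless entry; the cv2 imports in t2i/controlnet_union/guided_filter.py and mask.py below depend on it. A quick sanity-check sketch (output values will vary):

import cv2
import numpy as np

# Exercise the two cv2 routines the guided filter below relies on.
img = np.zeros((64, 64, 3), dtype=np.uint8)
blurred = cv2.blur(img, (5, 5))
resized = cv2.resize(img, (32, 32), interpolation=cv2.INTER_NEAREST)
print(cv2.__version__, blurred.shape, resized.shape)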
t2i/controlnet_union/guided_filter.py CHANGED
@@ -1,281 +1,281 @@
1
-
2
- # -*- coding: utf-8 -*-
3
- ## @package guided_filter.core.filters
4
- #
5
- # Implementation of guided filter.
6
- # * GuidedFilter: Original guided filter.
7
- # * FastGuidedFilter: Fast version of the guided filter.
8
- # @author tody
9
- # @date 2015/08/26
10
-
11
- import numpy as np
12
- import cv2
13
-
14
- ## Convert image into float32 type.
15
- def to32F(img):
16
- if img.dtype == np.float32:
17
- return img
18
- return (1.0 / 255.0) * np.float32(img)
19
-
20
- ## Convert image into uint8 type.
21
- def to8U(img):
22
- if img.dtype == np.uint8:
23
- return img
24
- return np.clip(np.uint8(255.0 * img), 0, 255)
25
-
26
- ## Return if the input image is gray or not.
27
- def _isGray(I):
28
- return len(I.shape) == 2
29
-
30
-
31
- ## Return down sampled image.
32
- # @param scale (w/s, h/s) image will be created.
33
- # @param shape I.shape[:2]=(h, w). numpy friendly size parameter.
34
- def _downSample(I, scale=4, shape=None):
35
- if shape is not None:
36
- h, w = shape
37
- return cv2.resize(I, (w, h), interpolation=cv2.INTER_NEAREST)
38
-
39
- h, w = I.shape[:2]
40
- return cv2.resize(I, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_NEAREST)
41
-
42
-
43
- ## Return up sampled image.
44
- # @param scale (w*s, h*s) image will be created.
45
- # @param shape I.shape[:2]=(h, w). numpy friendly size parameter.
46
- def _upSample(I, scale=2, shape=None):
47
- if shape is not None:
48
- h, w = shape
49
- return cv2.resize(I, (w, h), interpolation=cv2.INTER_LINEAR)
50
-
51
- h, w = I.shape[:2]
52
- return cv2.resize(I, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR)
53
-
54
- ## Fast guide filter.
55
- class FastGuidedFilter:
56
- ## Constructor.
57
- # @param I Input guidance image. Color or gray.
58
- # @param radius Radius of Guided Filter.
59
- # @param epsilon Regularization term of Guided Filter.
60
- # @param scale Down sampled scale.
61
- def __init__(self, I, radius=5, epsilon=0.4, scale=4):
62
- I_32F = to32F(I)
63
- self._I = I_32F
64
- h, w = I.shape[:2]
65
-
66
- I_sub = _downSample(I_32F, scale)
67
-
68
- self._I_sub = I_sub
69
- radius = int(radius / scale)
70
-
71
- if _isGray(I):
72
- self._guided_filter = GuidedFilterGray(I_sub, radius, epsilon)
73
- else:
74
- self._guided_filter = GuidedFilterColor(I_sub, radius, epsilon)
75
-
76
- ## Apply filter for the input image.
77
- # @param p Input image for the filtering.
78
- def filter(self, p):
79
- p_32F = to32F(p)
80
- shape_original = p.shape[:2]
81
-
82
- p_sub = _downSample(p_32F, shape=self._I_sub.shape[:2])
83
-
84
- if _isGray(p_sub):
85
- return self._filterGray(p_sub, shape_original)
86
-
87
- cs = p.shape[2]
88
- q = np.array(p_32F)
89
-
90
- for ci in range(cs):
91
- q[:, :, ci] = self._filterGray(p_sub[:, :, ci], shape_original)
92
- return to8U(q)
93
-
94
- def _filterGray(self, p_sub, shape_original):
95
- ab_sub = self._guided_filter._computeCoefficients(p_sub)
96
- ab = [_upSample(abi, shape=shape_original) for abi in ab_sub]
97
- return self._guided_filter._computeOutput(ab, self._I)
98
-
99
-
100
- ## Guide filter.
101
- class GuidedFilter:
102
- ## Constructor.
103
- # @param I Input guidance image. Color or gray.
104
- # @param radius Radius of Guided Filter.
105
- # @param epsilon Regularization term of Guided Filter.
106
- def __init__(self, I, radius=5, epsilon=0.4):
107
- I_32F = to32F(I)
108
-
109
- if _isGray(I):
110
- self._guided_filter = GuidedFilterGray(I_32F, radius, epsilon)
111
- else:
112
- self._guided_filter = GuidedFilterColor(I_32F, radius, epsilon)
113
-
114
- ## Apply filter for the input image.
115
- # @param p Input image for the filtering.
116
- def filter(self, p):
117
- return to8U(self._guided_filter.filter(p))
118
-
119
-
120
- ## Common parts of guided filter.
121
- #
122
- # This class is used by guided_filter class. GuidedFilterGray and GuidedFilterColor.
123
- # Based on guided_filter._computeCoefficients, guided_filter._computeOutput,
124
- # GuidedFilterCommon.filter computes filtered image for color and gray.
125
- class GuidedFilterCommon:
126
- def __init__(self, guided_filter):
127
- self._guided_filter = guided_filter
128
-
129
- ## Apply filter for the input image.
130
- # @param p Input image for the filtering.
131
- def filter(self, p):
132
- p_32F = to32F(p)
133
- if _isGray(p_32F):
134
- return self._filterGray(p_32F)
135
-
136
- cs = p.shape[2]
137
- q = np.array(p_32F)
138
-
139
- for ci in range(cs):
140
- q[:, :, ci] = self._filterGray(p_32F[:, :, ci])
141
- return q
142
-
143
- def _filterGray(self, p):
144
- ab = self._guided_filter._computeCoefficients(p)
145
- return self._guided_filter._computeOutput(ab, self._guided_filter._I)
146
-
147
-
148
- ## Guided filter for gray guidance image.
149
- class GuidedFilterGray:
150
- # @param I Input gray guidance image.
151
- # @param radius Radius of Guided Filter.
152
- # @param epsilon Regularization term of Guided Filter.
153
- def __init__(self, I, radius=5, epsilon=0.4):
154
- self._radius = 2 * radius + 1
155
- self._epsilon = epsilon
156
- self._I = to32F(I)
157
- self._initFilter()
158
- self._filter_common = GuidedFilterCommon(self)
159
-
160
- ## Apply filter for the input image.
161
- # @param p Input image for the filtering.
162
- def filter(self, p):
163
- return self._filter_common.filter(p)
164
-
165
- def _initFilter(self):
166
- I = self._I
167
- r = self._radius
168
- self._I_mean = cv2.blur(I, (r, r))
169
- I_mean_sq = cv2.blur(I ** 2, (r, r))
170
- self._I_var = I_mean_sq - self._I_mean ** 2
171
-
172
- def _computeCoefficients(self, p):
173
- r = self._radius
174
- p_mean = cv2.blur(p, (r, r))
175
- p_cov = p_mean - self._I_mean * p_mean
176
- a = p_cov / (self._I_var + self._epsilon)
177
- b = p_mean - a * self._I_mean
178
- a_mean = cv2.blur(a, (r, r))
179
- b_mean = cv2.blur(b, (r, r))
180
- return a_mean, b_mean
181
-
182
- def _computeOutput(self, ab, I):
183
- a_mean, b_mean = ab
184
- return a_mean * I + b_mean
185
-
186
-
187
- ## Guided filter for color guidance image.
188
- class GuidedFilterColor:
189
- # @param I Input color guidance image.
190
- # @param radius Radius of Guided Filter.
191
- # @param epsilon Regularization term of Guided Filter.
192
- def __init__(self, I, radius=5, epsilon=0.2):
193
- self._radius = 2 * radius + 1
194
- self._epsilon = epsilon
195
- self._I = to32F(I)
196
- self._initFilter()
197
- self._filter_common = GuidedFilterCommon(self)
198
-
199
- ## Apply filter for the input image.
200
- # @param p Input image for the filtering.
201
- def filter(self, p):
202
- return self._filter_common.filter(p)
203
-
204
- def _initFilter(self):
205
- I = self._I
206
- r = self._radius
207
- eps = self._epsilon
208
-
209
- Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2]
210
-
211
- self._Ir_mean = cv2.blur(Ir, (r, r))
212
- self._Ig_mean = cv2.blur(Ig, (r, r))
213
- self._Ib_mean = cv2.blur(Ib, (r, r))
214
-
215
- Irr_var = cv2.blur(Ir ** 2, (r, r)) - self._Ir_mean ** 2 + eps
216
- Irg_var = cv2.blur(Ir * Ig, (r, r)) - self._Ir_mean * self._Ig_mean
217
- Irb_var = cv2.blur(Ir * Ib, (r, r)) - self._Ir_mean * self._Ib_mean
218
- Igg_var = cv2.blur(Ig * Ig, (r, r)) - self._Ig_mean * self._Ig_mean + eps
219
- Igb_var = cv2.blur(Ig * Ib, (r, r)) - self._Ig_mean * self._Ib_mean
220
- Ibb_var = cv2.blur(Ib * Ib, (r, r)) - self._Ib_mean * self._Ib_mean + eps
221
-
222
- Irr_inv = Igg_var * Ibb_var - Igb_var * Igb_var
223
- Irg_inv = Igb_var * Irb_var - Irg_var * Ibb_var
224
- Irb_inv = Irg_var * Igb_var - Igg_var * Irb_var
225
- Igg_inv = Irr_var * Ibb_var - Irb_var * Irb_var
226
- Igb_inv = Irb_var * Irg_var - Irr_var * Igb_var
227
- Ibb_inv = Irr_var * Igg_var - Irg_var * Irg_var
228
-
229
- I_cov = Irr_inv * Irr_var + Irg_inv * Irg_var + Irb_inv * Irb_var
230
- Irr_inv /= I_cov
231
- Irg_inv /= I_cov
232
- Irb_inv /= I_cov
233
- Igg_inv /= I_cov
234
- Igb_inv /= I_cov
235
- Ibb_inv /= I_cov
236
-
237
- self._Irr_inv = Irr_inv
238
- self._Irg_inv = Irg_inv
239
- self._Irb_inv = Irb_inv
240
- self._Igg_inv = Igg_inv
241
- self._Igb_inv = Igb_inv
242
- self._Ibb_inv = Ibb_inv
243
-
244
- def _computeCoefficients(self, p):
245
- r = self._radius
246
- I = self._I
247
- Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2]
248
-
249
- p_mean = cv2.blur(p, (r, r))
250
-
251
- Ipr_mean = cv2.blur(Ir * p, (r, r))
252
- Ipg_mean = cv2.blur(Ig * p, (r, r))
253
- Ipb_mean = cv2.blur(Ib * p, (r, r))
254
-
255
- Ipr_cov = Ipr_mean - self._Ir_mean * p_mean
256
- Ipg_cov = Ipg_mean - self._Ig_mean * p_mean
257
- Ipb_cov = Ipb_mean - self._Ib_mean * p_mean
258
-
259
- ar = self._Irr_inv * Ipr_cov + self._Irg_inv * Ipg_cov + self._Irb_inv * Ipb_cov
260
- ag = self._Irg_inv * Ipr_cov + self._Igg_inv * Ipg_cov + self._Igb_inv * Ipb_cov
261
- ab = self._Irb_inv * Ipr_cov + self._Igb_inv * Ipg_cov + self._Ibb_inv * Ipb_cov
262
- b = p_mean - ar * self._Ir_mean - ag * self._Ig_mean - ab * self._Ib_mean
263
-
264
- ar_mean = cv2.blur(ar, (r, r))
265
- ag_mean = cv2.blur(ag, (r, r))
266
- ab_mean = cv2.blur(ab, (r, r))
267
- b_mean = cv2.blur(b, (r, r))
268
-
269
- return ar_mean, ag_mean, ab_mean, b_mean
270
-
271
- def _computeOutput(self, ab, I):
272
- ar_mean, ag_mean, ab_mean, b_mean = ab
273
-
274
- Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2]
275
-
276
- q = (ar_mean * Ir +
277
- ag_mean * Ig +
278
- ab_mean * Ib +
279
- b_mean)
280
-
281
  return q
 
1
+
2
+ # -*- coding: utf-8 -*-
3
+ ## @package guided_filter.core.filters
4
+ #
5
+ # Implementation of guided filter.
6
+ # * GuidedFilter: Original guided filter.
7
+ # * FastGuidedFilter: Fast version of the guided filter.
8
+ # @author tody
9
+ # @date 2015/08/26
10
+
11
+ import numpy as np
12
+ import cv2
13
+
14
+ ## Convert image into float32 type.
15
+ def to32F(img):
16
+ if img.dtype == np.float32:
17
+ return img
18
+ return (1.0 / 255.0) * np.float32(img)
19
+
20
+ ## Convert image into uint8 type.
21
+ def to8U(img):
22
+ if img.dtype == np.uint8:
23
+ return img
24
+ return np.clip(np.uint8(255.0 * img), 0, 255)
25
+
26
+ ## Return if the input image is gray or not.
27
+ def _isGray(I):
28
+ return len(I.shape) == 2
29
+
30
+
31
+ ## Return down sampled image.
32
+ # @param scale (w/s, h/s) image will be created.
33
+ # @param shape I.shape[:2]=(h, w). numpy friendly size parameter.
34
+ def _downSample(I, scale=4, shape=None):
35
+ if shape is not None:
36
+ h, w = shape
37
+ return cv2.resize(I, (w, h), interpolation=cv2.INTER_NEAREST)
38
+
39
+ h, w = I.shape[:2]
40
+ return cv2.resize(I, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_NEAREST)
41
+
42
+
43
+ ## Return up sampled image.
44
+ # @param scale (w*s, h*s) image will be created.
45
+ # @param shape I.shape[:2]=(h, w). numpy friendly size parameter.
46
+ def _upSample(I, scale=2, shape=None):
47
+ if shape is not None:
48
+ h, w = shape
49
+ return cv2.resize(I, (w, h), interpolation=cv2.INTER_LINEAR)
50
+
51
+ h, w = I.shape[:2]
52
+ return cv2.resize(I, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR)
53
+
54
+ ## Fast guide filter.
55
+ class FastGuidedFilter:
56
+ ## Constructor.
57
+ # @param I Input guidance image. Color or gray.
58
+ # @param radius Radius of Guided Filter.
59
+ # @param epsilon Regularization term of Guided Filter.
60
+ # @param scale Down sampled scale.
61
+ def __init__(self, I, radius=5, epsilon=0.4, scale=4):
62
+ I_32F = to32F(I)
63
+ self._I = I_32F
64
+ h, w = I.shape[:2]
65
+
66
+ I_sub = _downSample(I_32F, scale)
67
+
68
+ self._I_sub = I_sub
69
+ radius = int(radius / scale)
70
+
71
+ if _isGray(I):
72
+ self._guided_filter = GuidedFilterGray(I_sub, radius, epsilon)
73
+ else:
74
+ self._guided_filter = GuidedFilterColor(I_sub, radius, epsilon)
75
+
76
+ ## Apply filter for the input image.
77
+ # @param p Input image for the filtering.
78
+ def filter(self, p):
79
+ p_32F = to32F(p)
80
+ shape_original = p.shape[:2]
81
+
82
+ p_sub = _downSample(p_32F, shape=self._I_sub.shape[:2])
83
+
84
+ if _isGray(p_sub):
85
+ return self._filterGray(p_sub, shape_original)
86
+
87
+ cs = p.shape[2]
88
+ q = np.array(p_32F)
89
+
90
+ for ci in range(cs):
91
+ q[:, :, ci] = self._filterGray(p_sub[:, :, ci], shape_original)
92
+ return to8U(q)
93
+
94
+ def _filterGray(self, p_sub, shape_original):
95
+ ab_sub = self._guided_filter._computeCoefficients(p_sub)
96
+ ab = [_upSample(abi, shape=shape_original) for abi in ab_sub]
97
+ return self._guided_filter._computeOutput(ab, self._I)
98
+
99
+
100
+ ## Guide filter.
101
+ class GuidedFilter:
102
+ ## Constructor.
103
+ # @param I Input guidance image. Color or gray.
104
+ # @param radius Radius of Guided Filter.
105
+ # @param epsilon Regularization term of Guided Filter.
106
+ def __init__(self, I, radius=5, epsilon=0.4):
107
+ I_32F = to32F(I)
108
+
109
+ if _isGray(I):
110
+ self._guided_filter = GuidedFilterGray(I_32F, radius, epsilon)
111
+ else:
112
+ self._guided_filter = GuidedFilterColor(I_32F, radius, epsilon)
113
+
114
+ ## Apply filter for the input image.
115
+ # @param p Input image for the filtering.
116
+ def filter(self, p):
117
+ return to8U(self._guided_filter.filter(p))
118
+
119
+
120
+ ## Common parts of guided filter.
121
+ #
122
+ # This class is used by guided_filter class. GuidedFilterGray and GuidedFilterColor.
123
+ # Based on guided_filter._computeCoefficients, guided_filter._computeOutput,
124
+ # GuidedFilterCommon.filter computes filtered image for color and gray.
125
+ class GuidedFilterCommon:
126
+ def __init__(self, guided_filter):
127
+ self._guided_filter = guided_filter
128
+
129
+ ## Apply filter for the input image.
130
+ # @param p Input image for the filtering.
131
+ def filter(self, p):
132
+ p_32F = to32F(p)
133
+ if _isGray(p_32F):
134
+ return self._filterGray(p_32F)
135
+
136
+ cs = p.shape[2]
137
+ q = np.array(p_32F)
138
+
139
+ for ci in range(cs):
140
+ q[:, :, ci] = self._filterGray(p_32F[:, :, ci])
141
+ return q
142
+
143
+ def _filterGray(self, p):
144
+ ab = self._guided_filter._computeCoefficients(p)
145
+ return self._guided_filter._computeOutput(ab, self._guided_filter._I)
146
+
147
+
148
+ ## Guided filter for gray guidance image.
149
+ class GuidedFilterGray:
150
+ # @param I Input gray guidance image.
151
+ # @param radius Radius of Guided Filter.
152
+ # @param epsilon Regularization term of Guided Filter.
153
+ def __init__(self, I, radius=5, epsilon=0.4):
154
+ self._radius = 2 * radius + 1
155
+ self._epsilon = epsilon
156
+ self._I = to32F(I)
157
+ self._initFilter()
158
+ self._filter_common = GuidedFilterCommon(self)
159
+
160
+ ## Apply filter for the input image.
161
+ # @param p Input image for the filtering.
162
+ def filter(self, p):
163
+ return self._filter_common.filter(p)
164
+
165
+ def _initFilter(self):
166
+ I = self._I
167
+ r = self._radius
168
+ self._I_mean = cv2.blur(I, (r, r))
169
+ I_mean_sq = cv2.blur(I ** 2, (r, r))
170
+ self._I_var = I_mean_sq - self._I_mean ** 2
171
+
172
+ def _computeCoefficients(self, p):
173
+ r = self._radius
174
+ p_mean = cv2.blur(p, (r, r))
175
+ p_cov = p_mean - self._I_mean * p_mean
176
+ a = p_cov / (self._I_var + self._epsilon)
177
+ b = p_mean - a * self._I_mean
178
+ a_mean = cv2.blur(a, (r, r))
179
+ b_mean = cv2.blur(b, (r, r))
180
+ return a_mean, b_mean
181
+
182
+ def _computeOutput(self, ab, I):
183
+ a_mean, b_mean = ab
184
+ return a_mean * I + b_mean
185
+
186
+
187
+ ## Guided filter for color guidance image.
188
+ class GuidedFilterColor:
189
+ # @param I Input color guidance image.
190
+ # @param radius Radius of Guided Filter.
191
+ # @param epsilon Regularization term of Guided Filter.
192
+ def __init__(self, I, radius=5, epsilon=0.2):
193
+ self._radius = 2 * radius + 1
194
+ self._epsilon = epsilon
195
+ self._I = to32F(I)
196
+ self._initFilter()
197
+ self._filter_common = GuidedFilterCommon(self)
198
+
199
+ ## Apply filter for the input image.
200
+ # @param p Input image for the filtering.
201
+ def filter(self, p):
202
+ return self._filter_common.filter(p)
203
+
204
+ def _initFilter(self):
205
+ I = self._I
206
+ r = self._radius
207
+ eps = self._epsilon
208
+
209
+ Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2]
210
+
211
+ self._Ir_mean = cv2.blur(Ir, (r, r))
212
+ self._Ig_mean = cv2.blur(Ig, (r, r))
213
+ self._Ib_mean = cv2.blur(Ib, (r, r))
214
+
215
+ Irr_var = cv2.blur(Ir ** 2, (r, r)) - self._Ir_mean ** 2 + eps
216
+ Irg_var = cv2.blur(Ir * Ig, (r, r)) - self._Ir_mean * self._Ig_mean
217
+ Irb_var = cv2.blur(Ir * Ib, (r, r)) - self._Ir_mean * self._Ib_mean
218
+ Igg_var = cv2.blur(Ig * Ig, (r, r)) - self._Ig_mean * self._Ig_mean + eps
219
+ Igb_var = cv2.blur(Ig * Ib, (r, r)) - self._Ig_mean * self._Ib_mean
220
+ Ibb_var = cv2.blur(Ib * Ib, (r, r)) - self._Ib_mean * self._Ib_mean + eps
221
+
222
+ Irr_inv = Igg_var * Ibb_var - Igb_var * Igb_var
223
+ Irg_inv = Igb_var * Irb_var - Irg_var * Ibb_var
224
+ Irb_inv = Irg_var * Igb_var - Igg_var * Irb_var
225
+ Igg_inv = Irr_var * Ibb_var - Irb_var * Irb_var
226
+ Igb_inv = Irb_var * Irg_var - Irr_var * Igb_var
227
+ Ibb_inv = Irr_var * Igg_var - Irg_var * Irg_var
228
+
229
+ I_cov = Irr_inv * Irr_var + Irg_inv * Irg_var + Irb_inv * Irb_var
230
+ Irr_inv /= I_cov
231
+ Irg_inv /= I_cov
232
+ Irb_inv /= I_cov
233
+ Igg_inv /= I_cov
234
+ Igb_inv /= I_cov
235
+ Ibb_inv /= I_cov
236
+
237
+ self._Irr_inv = Irr_inv
238
+ self._Irg_inv = Irg_inv
239
+ self._Irb_inv = Irb_inv
240
+ self._Igg_inv = Igg_inv
241
+ self._Igb_inv = Igb_inv
242
+ self._Ibb_inv = Ibb_inv
243
+
244
+ def _computeCoefficients(self, p):
245
+ r = self._radius
246
+ I = self._I
247
+ Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2]
248
+
249
+ p_mean = cv2.blur(p, (r, r))
250
+
251
+ Ipr_mean = cv2.blur(Ir * p, (r, r))
252
+ Ipg_mean = cv2.blur(Ig * p, (r, r))
253
+ Ipb_mean = cv2.blur(Ib * p, (r, r))
254
+
255
+ Ipr_cov = Ipr_mean - self._Ir_mean * p_mean
256
+ Ipg_cov = Ipg_mean - self._Ig_mean * p_mean
257
+ Ipb_cov = Ipb_mean - self._Ib_mean * p_mean
258
+
259
+ ar = self._Irr_inv * Ipr_cov + self._Irg_inv * Ipg_cov + self._Irb_inv * Ipb_cov
260
+ ag = self._Irg_inv * Ipr_cov + self._Igg_inv * Ipg_cov + self._Igb_inv * Ipb_cov
261
+ ab = self._Irb_inv * Ipr_cov + self._Igb_inv * Ipg_cov + self._Ibb_inv * Ipb_cov
262
+ b = p_mean - ar * self._Ir_mean - ag * self._Ig_mean - ab * self._Ib_mean
263
+
264
+ ar_mean = cv2.blur(ar, (r, r))
265
+ ag_mean = cv2.blur(ag, (r, r))
266
+ ab_mean = cv2.blur(ab, (r, r))
267
+ b_mean = cv2.blur(b, (r, r))
268
+
269
+ return ar_mean, ag_mean, ab_mean, b_mean
270
+
271
+ def _computeOutput(self, ab, I):
272
+ ar_mean, ag_mean, ab_mean, b_mean = ab
273
+
274
+ Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2]
275
+
276
+ q = (ar_mean * Ir +
277
+ ag_mean * Ig +
278
+ ab_mean * Ib +
279
+ b_mean)
280
+
281
  return q
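For orientation, a minimal usage sketch of the filters defined above; the input path and parameter values are placeholders, not values used elsewhere in this repository.

import cv2
from t2i.controlnet_union.guided_filter import GuidedFilter, FastGuidedFilter

img = cv2.imread("input.png")  # placeholder path; any 8-bit BGR image works

# Edge-preserving smoothing with the image acting as its own guide.
gf = GuidedFilter(img, radius=5, epsilon=0.4)
smoothed = gf.filter(img)

# FastGuidedFilter computes the coefficients on a downsampled guide and
# upsamples them back to full resolution, trading accuracy for speed.
fgf = FastGuidedFilter(img, radius=5, epsilon=0.4, scale=4)
smoothed_fast = fgf.filter(img)

cv2.imwrite("smoothed.png", smoothed)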
t2i/controlnet_union/mask.py CHANGED
@@ -1,347 +1,347 @@
1
- import math
2
- import random
3
- import hashlib
4
- import logging
5
- from enum import Enum
6
-
7
- import cv2
8
- import numpy as np
9
-
10
- # from saicinpainting.evaluation.masks.mask import SegmentationMask
11
- # from saicinpainting.utils import LinearRamp
12
-
13
- LOGGER = logging.getLogger(__name__)
14
-
15
- class LinearRamp:
16
- def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0):
17
- self.start_value = start_value
18
- self.end_value = end_value
19
- self.start_iter = start_iter
20
- self.end_iter = end_iter
21
-
22
- def __call__(self, i):
23
- if i < self.start_iter:
24
- return self.start_value
25
- if i >= self.end_iter:
26
- return self.end_value
27
- part = (i - self.start_iter) / (self.end_iter - self.start_iter)
28
- return self.start_value * (1 - part) + self.end_value * part
29
-
30
-
31
- class DrawMethod(Enum):
32
- LINE = 'line'
33
- CIRCLE = 'circle'
34
- SQUARE = 'square'
35
-
36
-
37
- def make_random_irregular_mask(shape, max_angle=4, max_len=60, max_width=20, min_times=0, max_times=10,
38
- draw_method=DrawMethod.LINE):
39
- draw_method = DrawMethod(draw_method)
40
-
41
- height, width = shape
42
- mask = np.zeros((height, width), np.float32)
43
- times = np.random.randint(min_times, max_times + 1)
44
- for i in range(times):
45
- start_x = np.random.randint(width)
46
- start_y = np.random.randint(height)
47
- for j in range(1 + np.random.randint(5)):
48
- angle = 0.01 + np.random.randint(max_angle)
49
- if i % 2 == 0:
50
- angle = 2 * 3.1415926 - angle
51
- length = 10 + np.random.randint(max_len)
52
- brush_w = 5 + np.random.randint(max_width)
53
- end_x = np.clip((start_x + length * np.sin(angle)).astype(np.int32), 0, width)
54
- end_y = np.clip((start_y + length * np.cos(angle)).astype(np.int32), 0, height)
55
- if draw_method == DrawMethod.LINE:
56
- cv2.line(mask, (start_x, start_y), (end_x, end_y), 1.0, brush_w)
57
- elif draw_method == DrawMethod.CIRCLE:
58
- cv2.circle(mask, (start_x, start_y), radius=brush_w, color=1., thickness=-1)
59
- elif draw_method == DrawMethod.SQUARE:
60
- radius = brush_w // 2
61
- mask[start_y - radius:start_y + radius, start_x - radius:start_x + radius] = 1
62
- start_x, start_y = end_x, end_y
63
- return mask[None, ...]
64
-
65
-
66
- class RandomIrregularMaskGenerator:
67
- def __init__(self, max_angle=4, max_len=60, max_width=20, min_times=0, max_times=10, ramp_kwargs=None,
68
- draw_method=DrawMethod.LINE):
69
- self.max_angle = max_angle
70
- self.max_len = max_len
71
- self.max_width = max_width
72
- self.min_times = min_times
73
- self.max_times = max_times
74
- self.draw_method = draw_method
75
- self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None
76
-
77
- def __call__(self, img, iter_i=None, raw_image=None):
78
- coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1
79
- cur_max_len = int(max(1, self.max_len * coef))
80
- cur_max_width = int(max(1, self.max_width * coef))
81
- cur_max_times = int(self.min_times + 1 + (self.max_times - self.min_times) * coef)
82
- return make_random_irregular_mask(img.shape[1:], max_angle=self.max_angle, max_len=cur_max_len,
83
- max_width=cur_max_width, min_times=self.min_times, max_times=cur_max_times,
84
- draw_method=self.draw_method)
85
-
86
-
87
- def make_random_rectangle_mask(shape, margin=10, bbox_min_size=30, bbox_max_size=100, min_times=0, max_times=3):
88
- height, width = shape
89
- mask = np.zeros((height, width), np.float32)
90
- bbox_max_size = min(bbox_max_size, height - margin * 2, width - margin * 2)
91
- times = np.random.randint(min_times, max_times + 1)
92
- for i in range(times):
93
- box_width = np.random.randint(bbox_min_size, bbox_max_size)
94
- box_height = np.random.randint(bbox_min_size, bbox_max_size)
95
- start_x = np.random.randint(margin, width - margin - box_width + 1)
96
- start_y = np.random.randint(margin, height - margin - box_height + 1)
97
- mask[start_y:start_y + box_height, start_x:start_x + box_width] = 1
98
- return mask[None, ...]
99
-
100
-
101
- class RandomRectangleMaskGenerator:
102
- def __init__(self, margin=10, bbox_min_size=30, bbox_max_size=100, min_times=0, max_times=3, ramp_kwargs=None):
103
- self.margin = margin
104
- self.bbox_min_size = bbox_min_size
105
- self.bbox_max_size = bbox_max_size
106
- self.min_times = min_times
107
- self.max_times = max_times
108
- self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None
109
-
110
- def __call__(self, img, iter_i=None, raw_image=None):
111
- coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1
112
- cur_bbox_max_size = int(self.bbox_min_size + 1 + (self.bbox_max_size - self.bbox_min_size) * coef)
113
- cur_max_times = int(self.min_times + (self.max_times - self.min_times) * coef)
114
- return make_random_rectangle_mask(img.shape[1:], margin=self.margin, bbox_min_size=self.bbox_min_size,
115
- bbox_max_size=cur_bbox_max_size, min_times=self.min_times,
116
- max_times=cur_max_times)
117
-
118
-
119
- # class RandomSegmentationMaskGenerator:
120
- # def __init__(self, **kwargs):
121
- # self.impl = None # will be instantiated in first call (effectively in subprocess)
122
- # self.kwargs = kwargs
123
-
124
- # def __call__(self, img, iter_i=None, raw_image=None):
125
- # if self.impl is None:
126
- # self.impl = SegmentationMask(**self.kwargs)
127
-
128
- # masks = self.impl.get_masks(np.transpose(img, (1, 2, 0)))
129
- # masks = [m for m in masks if len(np.unique(m)) > 1]
130
- # return np.random.choice(masks)
131
-
132
-
133
- def make_random_superres_mask(shape, min_step=2, max_step=4, min_width=1, max_width=3):
134
- height, width = shape
135
- mask = np.zeros((height, width), np.float32)
136
- step_x = np.random.randint(min_step, max_step + 1)
137
- width_x = np.random.randint(min_width, min(step_x, max_width + 1))
138
- offset_x = np.random.randint(0, step_x)
139
-
140
- step_y = np.random.randint(min_step, max_step + 1)
141
- width_y = np.random.randint(min_width, min(step_y, max_width + 1))
142
- offset_y = np.random.randint(0, step_y)
143
-
144
- for dy in range(width_y):
145
- mask[offset_y + dy::step_y] = 1
146
- for dx in range(width_x):
147
- mask[:, offset_x + dx::step_x] = 1
148
- return mask[None, ...]
149
-
150
-
151
- class RandomSuperresMaskGenerator:
152
- def __init__(self, **kwargs):
153
- self.kwargs = kwargs
154
-
155
- def __call__(self, img, iter_i=None):
156
- return make_random_superres_mask(img.shape[1:], **self.kwargs)
157
-
158
-
159
- class DumbAreaMaskGenerator:
160
- min_ratio = 0.1
161
- max_ratio = 0.35
162
- default_ratio = 0.225
163
-
164
- def __init__(self, is_training):
165
- #Parameters:
166
- # is_training(bool): If true - random rectangular mask, if false - central square mask
167
- self.is_training = is_training
168
-
169
- def _random_vector(self, dimension):
170
- if self.is_training:
171
- lower_limit = math.sqrt(self.min_ratio)
172
- upper_limit = math.sqrt(self.max_ratio)
173
- mask_side = round((random.random() * (upper_limit - lower_limit) + lower_limit) * dimension)
174
- u = random.randint(0, dimension-mask_side-1)
175
- v = u+mask_side
176
- else:
177
- margin = (math.sqrt(self.default_ratio) / 2) * dimension
178
- u = round(dimension/2 - margin)
179
- v = round(dimension/2 + margin)
180
- return u, v
181
-
182
- def __call__(self, img, iter_i=None, raw_image=None):
183
- c, height, width = img.shape
184
- mask = np.zeros((height, width), np.float32)
185
- x1, x2 = self._random_vector(width)
186
- y1, y2 = self._random_vector(height)
187
- mask[x1:x2, y1:y2] = 1
188
- return mask[None, ...]
189
-
190
-
191
- class OutpaintingMaskGenerator:
192
- def __init__(self, min_padding_percent:float=0.04, max_padding_percent:int=0.25, left_padding_prob:float=0.5, top_padding_prob:float=0.5,
193
- right_padding_prob:float=0.5, bottom_padding_prob:float=0.5, is_fixed_randomness:bool=False):
194
- """
195
- is_fixed_randomness - get identical paddings for the same image if args are the same
196
- """
197
- self.min_padding_percent = min_padding_percent
198
- self.max_padding_percent = max_padding_percent
199
- self.probs = [left_padding_prob, top_padding_prob, right_padding_prob, bottom_padding_prob]
200
- self.is_fixed_randomness = is_fixed_randomness
201
-
202
- assert self.min_padding_percent <= self.max_padding_percent
203
- assert self.max_padding_percent > 0
204
- assert len([x for x in [self.min_padding_percent, self.max_padding_percent] if (x>=0 and x<=1)]) == 2, f"Padding percentage should be in [0,1]"
205
- assert sum(self.probs) > 0, f"At least one of the padding probs should be greater than 0 - {self.probs}"
206
- assert len([x for x in self.probs if (x >= 0) and (x <= 1)]) == 4, f"At least one of padding probs is not in [0,1] - {self.probs}"
207
- if len([x for x in self.probs if x > 0]) == 1:
208
- LOGGER.warning(f"Only one padding prob is greater than zero - {self.probs}. That means that the outpainting masks will be always on the same side")
209
-
210
- def apply_padding(self, mask, coord):
211
- mask[int(coord[0][0]*self.img_h):int(coord[1][0]*self.img_h),
212
- int(coord[0][1]*self.img_w):int(coord[1][1]*self.img_w)] = 1
213
- return mask
214
-
215
- def get_padding(self, size):
216
- n1 = int(self.min_padding_percent*size)
217
- n2 = int(self.max_padding_percent*size)
218
- return self.rnd.randint(n1, n2) / size
219
-
220
- @staticmethod
221
- def _img2rs(img):
222
- arr = np.ascontiguousarray(img.astype(np.uint8))
223
- str_hash = hashlib.sha1(arr).hexdigest()
224
- res = hash(str_hash)%(2**32)
225
- return res
226
-
227
- def __call__(self, img, iter_i=None, raw_image=None):
228
- c, self.img_h, self.img_w = img.shape
229
- mask = np.zeros((self.img_h, self.img_w), np.float32)
230
- at_least_one_mask_applied = False
231
-
232
- if self.is_fixed_randomness:
233
- assert raw_image is not None, f"Cant calculate hash on raw_image=None"
234
- rs = self._img2rs(raw_image)
235
- self.rnd = np.random.RandomState(rs)
236
- else:
237
- self.rnd = np.random
238
-
239
- coords = [[
240
- (0,0),
241
- (1,self.get_padding(size=self.img_h))
242
- ],
243
- [
244
- (0,0),
245
- (self.get_padding(size=self.img_w),1)
246
- ],
247
- [
248
- (0,1-self.get_padding(size=self.img_h)),
249
- (1,1)
250
- ],
251
- [
252
- (1-self.get_padding(size=self.img_w),0),
253
- (1,1)
254
- ]]
255
-
256
- for pp, coord in zip(self.probs, coords):
257
- if self.rnd.random() < pp:
258
- at_least_one_mask_applied = True
259
- mask = self.apply_padding(mask=mask, coord=coord)
260
-
261
- if not at_least_one_mask_applied:
262
- idx = self.rnd.choice(range(len(coords)), p=np.array(self.probs)/sum(self.probs))
263
- mask = self.apply_padding(mask=mask, coord=coords[idx])
264
- return mask[None, ...]
265
-
266
-
267
- class MixedMaskGenerator:
268
- def __init__(self, irregular_proba=1/3, irregular_kwargs=None,
269
- box_proba=1/3, box_kwargs=None,
270
- segm_proba=1/3, segm_kwargs=None,
271
- squares_proba=0, squares_kwargs=None,
272
- superres_proba=0, superres_kwargs=None,
273
- outpainting_proba=0, outpainting_kwargs=None,
274
- invert_proba=0):
275
- self.probas = []
276
- self.gens = []
277
-
278
- if irregular_proba > 0:
279
- self.probas.append(irregular_proba)
280
- if irregular_kwargs is None:
281
- irregular_kwargs = {}
282
- else:
283
- irregular_kwargs = dict(irregular_kwargs)
284
- irregular_kwargs['draw_method'] = DrawMethod.LINE
285
- self.gens.append(RandomIrregularMaskGenerator(**irregular_kwargs))
286
-
287
- if box_proba > 0:
288
- self.probas.append(box_proba)
289
- if box_kwargs is None:
290
- box_kwargs = {}
291
- self.gens.append(RandomRectangleMaskGenerator(**box_kwargs))
292
-
293
- # if segm_proba > 0:
294
- # self.probas.append(segm_proba)
295
- # if segm_kwargs is None:
296
- # segm_kwargs = {}
297
- # self.gens.append(RandomSegmentationMaskGenerator(**segm_kwargs))
298
-
299
- if squares_proba > 0:
300
- self.probas.append(squares_proba)
301
- if squares_kwargs is None:
302
- squares_kwargs = {}
303
- else:
304
- squares_kwargs = dict(squares_kwargs)
305
- squares_kwargs['draw_method'] = DrawMethod.SQUARE
306
- self.gens.append(RandomIrregularMaskGenerator(**squares_kwargs))
307
-
308
- if superres_proba > 0:
309
- self.probas.append(superres_proba)
310
- if superres_kwargs is None:
311
- superres_kwargs = {}
312
- self.gens.append(RandomSuperresMaskGenerator(**superres_kwargs))
313
-
314
- if outpainting_proba > 0:
315
- self.probas.append(outpainting_proba)
316
- if outpainting_kwargs is None:
317
- outpainting_kwargs = {}
318
- self.gens.append(OutpaintingMaskGenerator(**outpainting_kwargs))
319
-
320
- self.probas = np.array(self.probas, dtype='float32')
321
- self.probas /= self.probas.sum()
322
- self.invert_proba = invert_proba
323
-
324
- def __call__(self, img, iter_i=None, raw_image=None):
325
- kind = np.random.choice(len(self.probas), p=self.probas)
326
- gen = self.gens[kind]
327
- result = gen(img, iter_i=iter_i, raw_image=raw_image)
328
- if self.invert_proba > 0 and random.random() < self.invert_proba:
329
- result = 1 - result
330
- return result
331
-
332
-
333
- def get_mask_generator(kind, kwargs):
334
- if kind is None:
335
- kind = "mixed"
336
- if kwargs is None:
337
- kwargs = {}
338
-
339
- if kind == "mixed":
340
- cl = MixedMaskGenerator
341
- elif kind == "outpainting":
342
- cl = OutpaintingMaskGenerator
343
- elif kind == "dumb":
344
- cl = DumbAreaMaskGenerator
345
- else:
346
- raise NotImplementedError(f"No such generator kind = {kind}")
347
- return cl(**kwargs)
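For reference, a minimal sketch of driving the mask generators defined in this module; the array shape and keyword values are illustrative only.

import numpy as np
from t2i.controlnet_union.mask import get_mask_generator

# Generators expect a channel-first (C, H, W) array and return a (1, H, W) float mask.
img = np.zeros((3, 512, 512), dtype=np.float32)

gen = get_mask_generator(kind="mixed", kwargs={"irregular_proba": 0.5, "box_proba": 0.5})
mask = gen(img)
print(mask.shape, float(mask.min()), float(mask.max()))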
 
1
+ import math
2
+ import random
3
+ import hashlib
4
+ import logging
5
+ from enum import Enum
6
+
7
+ import cv2
8
+ import numpy as np
9
+
10
+ # from saicinpainting.evaluation.masks.mask import SegmentationMask
11
+ # from saicinpainting.utils import LinearRamp
12
+
13
+ LOGGER = logging.getLogger(__name__)
14
+
15
+ class LinearRamp:
16
+ def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0):
17
+ self.start_value = start_value
18
+ self.end_value = end_value
19
+ self.start_iter = start_iter
20
+ self.end_iter = end_iter
21
+
22
+ def __call__(self, i):
23
+ if i < self.start_iter:
24
+ return self.start_value
25
+ if i >= self.end_iter:
26
+ return self.end_value
27
+ part = (i - self.start_iter) / (self.end_iter - self.start_iter)
28
+ return self.start_value * (1 - part) + self.end_value * part
29
+
30
+
31
+ class DrawMethod(Enum):
32
+ LINE = 'line'
33
+ CIRCLE = 'circle'
34
+ SQUARE = 'square'
35
+
36
+
37
+ def make_random_irregular_mask(shape, max_angle=4, max_len=60, max_width=20, min_times=0, max_times=10,
38
+ draw_method=DrawMethod.LINE):
39
+ draw_method = DrawMethod(draw_method)
40
+
41
+ height, width = shape
42
+ mask = np.zeros((height, width), np.float32)
43
+ times = np.random.randint(min_times, max_times + 1)
44
+ for i in range(times):
45
+ start_x = np.random.randint(width)
46
+ start_y = np.random.randint(height)
47
+ for j in range(1 + np.random.randint(5)):
48
+ angle = 0.01 + np.random.randint(max_angle)
49
+ if i % 2 == 0:
50
+ angle = 2 * 3.1415926 - angle
51
+ length = 10 + np.random.randint(max_len)
52
+ brush_w = 5 + np.random.randint(max_width)
53
+ end_x = np.clip((start_x + length * np.sin(angle)).astype(np.int32), 0, width)
54
+ end_y = np.clip((start_y + length * np.cos(angle)).astype(np.int32), 0, height)
55
+ if draw_method == DrawMethod.LINE:
56
+ cv2.line(mask, (start_x, start_y), (end_x, end_y), 1.0, brush_w)
57
+ elif draw_method == DrawMethod.CIRCLE:
58
+ cv2.circle(mask, (start_x, start_y), radius=brush_w, color=1., thickness=-1)
59
+ elif draw_method == DrawMethod.SQUARE:
60
+ radius = brush_w // 2
61
+ mask[start_y - radius:start_y + radius, start_x - radius:start_x + radius] = 1
62
+ start_x, start_y = end_x, end_y
63
+ return mask[None, ...]
64
+
65
+
66
+ class RandomIrregularMaskGenerator:
67
+ def __init__(self, max_angle=4, max_len=60, max_width=20, min_times=0, max_times=10, ramp_kwargs=None,
68
+ draw_method=DrawMethod.LINE):
69
+ self.max_angle = max_angle
70
+ self.max_len = max_len
71
+ self.max_width = max_width
72
+ self.min_times = min_times
73
+ self.max_times = max_times
74
+ self.draw_method = draw_method
75
+ self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None
76
+
77
+ def __call__(self, img, iter_i=None, raw_image=None):
78
+ coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1
79
+ cur_max_len = int(max(1, self.max_len * coef))
80
+ cur_max_width = int(max(1, self.max_width * coef))
81
+ cur_max_times = int(self.min_times + 1 + (self.max_times - self.min_times) * coef)
82
+ return make_random_irregular_mask(img.shape[1:], max_angle=self.max_angle, max_len=cur_max_len,
83
+ max_width=cur_max_width, min_times=self.min_times, max_times=cur_max_times,
84
+ draw_method=self.draw_method)
85
+
86
+
87
+ def make_random_rectangle_mask(shape, margin=10, bbox_min_size=30, bbox_max_size=100, min_times=0, max_times=3):
88
+ height, width = shape
89
+ mask = np.zeros((height, width), np.float32)
90
+ bbox_max_size = min(bbox_max_size, height - margin * 2, width - margin * 2)
91
+ times = np.random.randint(min_times, max_times + 1)
92
+ for i in range(times):
93
+ box_width = np.random.randint(bbox_min_size, bbox_max_size)
94
+ box_height = np.random.randint(bbox_min_size, bbox_max_size)
95
+ start_x = np.random.randint(margin, width - margin - box_width + 1)
96
+ start_y = np.random.randint(margin, height - margin - box_height + 1)
97
+ mask[start_y:start_y + box_height, start_x:start_x + box_width] = 1
98
+ return mask[None, ...]
99
+
100
+
101
+ class RandomRectangleMaskGenerator:
102
+ def __init__(self, margin=10, bbox_min_size=30, bbox_max_size=100, min_times=0, max_times=3, ramp_kwargs=None):
103
+ self.margin = margin
104
+ self.bbox_min_size = bbox_min_size
105
+ self.bbox_max_size = bbox_max_size
106
+ self.min_times = min_times
107
+ self.max_times = max_times
108
+ self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None
109
+
110
+ def __call__(self, img, iter_i=None, raw_image=None):
111
+ coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1
112
+ cur_bbox_max_size = int(self.bbox_min_size + 1 + (self.bbox_max_size - self.bbox_min_size) * coef)
113
+ cur_max_times = int(self.min_times + (self.max_times - self.min_times) * coef)
114
+ return make_random_rectangle_mask(img.shape[1:], margin=self.margin, bbox_min_size=self.bbox_min_size,
115
+ bbox_max_size=cur_bbox_max_size, min_times=self.min_times,
116
+ max_times=cur_max_times)
117
+
118
+
119
+ # class RandomSegmentationMaskGenerator:
120
+ # def __init__(self, **kwargs):
121
+ # self.impl = None # will be instantiated in first call (effectively in subprocess)
122
+ # self.kwargs = kwargs
123
+
124
+ # def __call__(self, img, iter_i=None, raw_image=None):
125
+ # if self.impl is None:
126
+ # self.impl = SegmentationMask(**self.kwargs)
127
+
128
+ # masks = self.impl.get_masks(np.transpose(img, (1, 2, 0)))
129
+ # masks = [m for m in masks if len(np.unique(m)) > 1]
130
+ # return np.random.choice(masks)
131
+
132
+
133
+ def make_random_superres_mask(shape, min_step=2, max_step=4, min_width=1, max_width=3):
134
+ height, width = shape
135
+ mask = np.zeros((height, width), np.float32)
136
+ step_x = np.random.randint(min_step, max_step + 1)
137
+ width_x = np.random.randint(min_width, min(step_x, max_width + 1))
138
+ offset_x = np.random.randint(0, step_x)
139
+
140
+ step_y = np.random.randint(min_step, max_step + 1)
141
+ width_y = np.random.randint(min_width, min(step_y, max_width + 1))
142
+ offset_y = np.random.randint(0, step_y)
143
+
144
+ for dy in range(width_y):
145
+ mask[offset_y + dy::step_y] = 1
146
+ for dx in range(width_x):
147
+ mask[:, offset_x + dx::step_x] = 1
148
+ return mask[None, ...]
149
+
150
+
151
+ class RandomSuperresMaskGenerator:
152
+ def __init__(self, **kwargs):
153
+ self.kwargs = kwargs
154
+
155
+ def __call__(self, img, iter_i=None):
156
+ return make_random_superres_mask(img.shape[1:], **self.kwargs)
157
+
158
+
159
+ class DumbAreaMaskGenerator:
160
+ min_ratio = 0.1
161
+ max_ratio = 0.35
162
+ default_ratio = 0.225
163
+
164
+ def __init__(self, is_training):
165
+ #Parameters:
166
+ # is_training(bool): If true - random rectangular mask, if false - central square mask
167
+ self.is_training = is_training
168
+
169
+ def _random_vector(self, dimension):
170
+ if self.is_training:
171
+ lower_limit = math.sqrt(self.min_ratio)
172
+ upper_limit = math.sqrt(self.max_ratio)
173
+ mask_side = round((random.random() * (upper_limit - lower_limit) + lower_limit) * dimension)
174
+ u = random.randint(0, dimension-mask_side-1)
175
+ v = u+mask_side
176
+ else:
177
+ margin = (math.sqrt(self.default_ratio) / 2) * dimension
178
+ u = round(dimension/2 - margin)
179
+ v = round(dimension/2 + margin)
180
+ return u, v
181
+
182
+ def __call__(self, img, iter_i=None, raw_image=None):
183
+ c, height, width = img.shape
184
+ mask = np.zeros((height, width), np.float32)
185
+ x1, x2 = self._random_vector(width)
186
+ y1, y2 = self._random_vector(height)
187
+ mask[y1:y2, x1:x2] = 1  # rows are indexed by the height-based vector, columns by the width-based one
188
+ return mask[None, ...]
189
+
190
+
191
+ class OutpaintingMaskGenerator:
192
+ def __init__(self, min_padding_percent:float=0.04, max_padding_percent:float=0.25, left_padding_prob:float=0.5, top_padding_prob:float=0.5,
193
+ right_padding_prob:float=0.5, bottom_padding_prob:float=0.5, is_fixed_randomness:bool=False):
194
+ """
195
+ is_fixed_randomness - if True, produce identical paddings for the same image when the arguments are the same
196
+ """
197
+ self.min_padding_percent = min_padding_percent
198
+ self.max_padding_percent = max_padding_percent
199
+ self.probs = [left_padding_prob, top_padding_prob, right_padding_prob, bottom_padding_prob]
200
+ self.is_fixed_randomness = is_fixed_randomness
201
+
202
+ assert self.min_padding_percent <= self.max_padding_percent
203
+ assert self.max_padding_percent > 0
204
+ assert len([x for x in [self.min_padding_percent, self.max_padding_percent] if (x>=0 and x<=1)]) == 2, f"Padding percentage should be in [0,1]"
205
+ assert sum(self.probs) > 0, f"At least one of the padding probs should be greater than 0 - {self.probs}"
206
+ assert len([x for x in self.probs if (x >= 0) and (x <= 1)]) == 4, f"At least one of padding probs is not in [0,1] - {self.probs}"
207
+ if len([x for x in self.probs if x > 0]) == 1:
208
+ LOGGER.warning(f"Only one padding prob is greater than zero - {self.probs}. That means that the outpainting masks will be always on the same side")
209
+
210
+ def apply_padding(self, mask, coord):
211
+ mask[int(coord[0][0]*self.img_h):int(coord[1][0]*self.img_h),
212
+ int(coord[0][1]*self.img_w):int(coord[1][1]*self.img_w)] = 1
213
+ return mask
214
+
215
+ def get_padding(self, size):
216
+ n1 = int(self.min_padding_percent*size)
217
+ n2 = int(self.max_padding_percent*size)
218
+ return self.rnd.randint(n1, n2) / size
219
+
220
+ @staticmethod
221
+ def _img2rs(img):
222
+ arr = np.ascontiguousarray(img.astype(np.uint8))
223
+ str_hash = hashlib.sha1(arr).hexdigest()
224
+ res = hash(str_hash)%(2**32)
225
+ return res
226
+
227
+ def __call__(self, img, iter_i=None, raw_image=None):
228
+ c, self.img_h, self.img_w = img.shape
229
+ mask = np.zeros((self.img_h, self.img_w), np.float32)
230
+ at_least_one_mask_applied = False
231
+
232
+ if self.is_fixed_randomness:
233
+ assert raw_image is not None, f"Cant calculate hash on raw_image=None"
234
+ rs = self._img2rs(raw_image)
235
+ self.rnd = np.random.RandomState(rs)
236
+ else:
237
+ self.rnd = np.random
238
+
239
+ coords = [[
240
+ (0,0),
241
+ (1,self.get_padding(size=self.img_h))
242
+ ],
243
+ [
244
+ (0,0),
245
+ (self.get_padding(size=self.img_w),1)
246
+ ],
247
+ [
248
+ (0,1-self.get_padding(size=self.img_h)),
249
+ (1,1)
250
+ ],
251
+ [
252
+ (1-self.get_padding(size=self.img_w),0),
253
+ (1,1)
254
+ ]]
255
+
256
+ for pp, coord in zip(self.probs, coords):
257
+ if self.rnd.random() < pp:
258
+ at_least_one_mask_applied = True
259
+ mask = self.apply_padding(mask=mask, coord=coord)
260
+
261
+ if not at_least_one_mask_applied:
262
+ idx = self.rnd.choice(range(len(coords)), p=np.array(self.probs)/sum(self.probs))
263
+ mask = self.apply_padding(mask=mask, coord=coords[idx])
264
+ return mask[None, ...]
265
+
266
+
267
+ class MixedMaskGenerator:
268
+ def __init__(self, irregular_proba=1/3, irregular_kwargs=None,
269
+ box_proba=1/3, box_kwargs=None,
270
+ segm_proba=1/3, segm_kwargs=None,
271
+ squares_proba=0, squares_kwargs=None,
272
+ superres_proba=0, superres_kwargs=None,
273
+ outpainting_proba=0, outpainting_kwargs=None,
274
+ invert_proba=0):
275
+ self.probas = []
276
+ self.gens = []
277
+
278
+ if irregular_proba > 0:
279
+ self.probas.append(irregular_proba)
280
+ if irregular_kwargs is None:
281
+ irregular_kwargs = {}
282
+ else:
283
+ irregular_kwargs = dict(irregular_kwargs)
284
+ irregular_kwargs['draw_method'] = DrawMethod.LINE
285
+ self.gens.append(RandomIrregularMaskGenerator(**irregular_kwargs))
286
+
287
+ if box_proba > 0:
288
+ self.probas.append(box_proba)
289
+ if box_kwargs is None:
290
+ box_kwargs = {}
291
+ self.gens.append(RandomRectangleMaskGenerator(**box_kwargs))
292
+
293
+ # if segm_proba > 0:
294
+ # self.probas.append(segm_proba)
295
+ # if segm_kwargs is None:
296
+ # segm_kwargs = {}
297
+ # self.gens.append(RandomSegmentationMaskGenerator(**segm_kwargs))
298
+
299
+ if squares_proba > 0:
300
+ self.probas.append(squares_proba)
301
+ if squares_kwargs is None:
302
+ squares_kwargs = {}
303
+ else:
304
+ squares_kwargs = dict(squares_kwargs)
305
+ squares_kwargs['draw_method'] = DrawMethod.SQUARE
306
+ self.gens.append(RandomIrregularMaskGenerator(**squares_kwargs))
307
+
308
+ if superres_proba > 0:
309
+ self.probas.append(superres_proba)
310
+ if superres_kwargs is None:
311
+ superres_kwargs = {}
312
+ self.gens.append(RandomSuperresMaskGenerator(**superres_kwargs))
313
+
314
+ if outpainting_proba > 0:
315
+ self.probas.append(outpainting_proba)
316
+ if outpainting_kwargs is None:
317
+ outpainting_kwargs = {}
318
+ self.gens.append(OutpaintingMaskGenerator(**outpainting_kwargs))
319
+
320
+ self.probas = np.array(self.probas, dtype='float32')
321
+ self.probas /= self.probas.sum()
322
+ self.invert_proba = invert_proba
323
+
324
+ def __call__(self, img, iter_i=None, raw_image=None):
325
+ kind = np.random.choice(len(self.probas), p=self.probas)
326
+ gen = self.gens[kind]
327
+ result = gen(img, iter_i=iter_i, raw_image=raw_image)
328
+ if self.invert_proba > 0 and random.random() < self.invert_proba:
329
+ result = 1 - result
330
+ return result
331
+
332
+
333
+ def get_mask_generator(kind, kwargs):
334
+ if kind is None:
335
+ kind = "mixed"
336
+ if kwargs is None:
337
+ kwargs = {}
338
+
339
+ if kind == "mixed":
340
+ cl = MixedMaskGenerator
341
+ elif kind == "outpainting":
342
+ cl = OutpaintingMaskGenerator
343
+ elif kind == "dumb":
344
+ cl = DumbAreaMaskGenerator
345
+ else:
346
+ raise NotImplementedError(f"No such generator kind = {kind}")
347
+ return cl(**kwargs)
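A minimal usage sketch for the mask generators defined above (illustration only, not part of the uploaded file). The import path and the kwargs below are assumptions; `get_mask_generator`, `MixedMaskGenerator` and `OutpaintingMaskGenerator` come from this module.

import numpy as np
from masks import get_mask_generator  # hypothetical import path; adjust to where this module lives in the repo

img = np.zeros((3, 256, 256), dtype=np.float32)  # dummy C x H x W image

# "mixed" draws irregular strokes and rectangles, optionally inverting the result
gen = get_mask_generator("mixed", {"invert_proba": 0.2})
mask = gen(img)  # ndarray of shape (1, 256, 256) with values in {0, 1}

# "outpainting" masks one or more borders of the image
out_gen = get_mask_generator("outpainting", {"min_padding_percent": 0.05, "max_padding_percent": 0.25})
out_mask = out_gen(img)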
t2i/controlnet_union/models/controlnet_union.py CHANGED
@@ -1,957 +1,957 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from dataclasses import dataclass
15
- from typing import Any, Dict, List, Optional, Tuple, Union
16
-
17
- import torch
18
- from torch import nn
19
- from torch.nn import functional as F
20
-
21
- from diffusers.configuration_utils import ConfigMixin, register_to_config
22
- from diffusers.loaders.single_file_model import FromOriginalModelMixin
23
- from diffusers.utils import BaseOutput, logging
24
- from diffusers.models.attention_processor import (
25
- ADDED_KV_ATTENTION_PROCESSORS,
26
- CROSS_ATTENTION_PROCESSORS,
27
- AttentionProcessor,
28
- AttnAddedKVProcessor,
29
- AttnProcessor,
30
- )
31
- from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
32
- from diffusers.models.modeling_utils import ModelMixin
33
- from diffusers.models.unets.unet_2d_blocks import (
34
- CrossAttnDownBlock2D,
35
- DownBlock2D,
36
- UNetMidBlock2DCrossAttn,
37
- get_down_block,
38
- )
39
- from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
40
-
41
-
42
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
43
-
44
-
45
- from collections import OrderedDict
46
-
47
- # Transformer Block
48
- # Used to exchange info between different conditions and input image
49
- # With reference to https://github.com/TencentARC/T2I-Adapter/blob/SD/ldm/modules/encoders/adapter.py#L147
50
- class QuickGELU(nn.Module):
51
-
52
- def forward(self, x: torch.Tensor):
53
- return x * torch.sigmoid(1.702 * x)
54
-
55
- class LayerNorm(nn.LayerNorm):
56
- """Subclass torch's LayerNorm to handle fp16."""
57
-
58
- def forward(self, x: torch.Tensor):
59
- orig_type = x.dtype
60
- ret = super().forward(x)
61
- return ret.type(orig_type)
62
-
63
- class ResidualAttentionBlock(nn.Module):
64
-
65
- def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
66
- super().__init__()
67
-
68
- self.attn = nn.MultiheadAttention(d_model, n_head)
69
- self.ln_1 = LayerNorm(d_model)
70
- self.mlp = nn.Sequential(
71
- OrderedDict([("c_fc", nn.Linear(d_model, d_model * 4)), ("gelu", QuickGELU()),
72
- ("c_proj", nn.Linear(d_model * 4, d_model))]))
73
- self.ln_2 = LayerNorm(d_model)
74
- self.attn_mask = attn_mask
75
-
76
- def attention(self, x: torch.Tensor):
77
- self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
78
- return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
79
-
80
- def forward(self, x: torch.Tensor):
81
- x = x + self.attention(self.ln_1(x))
82
- x = x + self.mlp(self.ln_2(x))
83
- return x
84
- #-----------------------------------------------------------------------------------------------------
85
-
86
- @dataclass
87
- class ControlNetOutput(BaseOutput):
88
- """
89
- The output of [`ControlNetModel`].
90
-
91
- Args:
92
- down_block_res_samples (`tuple[torch.Tensor]`):
93
- A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
94
- be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
95
- used to condition the original UNet's downsampling activations.
96
- mid_block_res_sample (`torch.Tensor`):
97
- The activation of the middle block (the lowest sample resolution). The tensor should be of shape
98
- `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
99
- Output can be used to condition the original UNet's middle block activation.
100
- """
101
-
102
- down_block_res_samples: Tuple[torch.Tensor]
103
- mid_block_res_sample: torch.Tensor
104
-
105
-
106
- class ControlNetConditioningEmbedding(nn.Module):
107
- """
108
- Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
109
- [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
110
- training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
111
- convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
112
- (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
113
- model) to encode image-space conditions ... into feature maps ..."
114
- """
115
-
116
- # original setting is (16, 32, 96, 256)
117
- def __init__(
118
- self,
119
- conditioning_embedding_channels: int,
120
- conditioning_channels: int = 3,
121
- block_out_channels: Tuple[int] = (48, 96, 192, 384),
122
- ):
123
- super().__init__()
124
-
125
- self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
126
-
127
- self.blocks = nn.ModuleList([])
128
-
129
- for i in range(len(block_out_channels) - 1):
130
- channel_in = block_out_channels[i]
131
- channel_out = block_out_channels[i + 1]
132
- self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
133
- self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
134
-
135
- self.conv_out = zero_module(
136
- nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
137
- )
138
-
139
- def forward(self, conditioning):
140
- embedding = self.conv_in(conditioning)
141
- embedding = F.silu(embedding)
142
-
143
- for block in self.blocks:
144
- embedding = block(embedding)
145
- embedding = F.silu(embedding)
146
-
147
- embedding = self.conv_out(embedding)
148
-
149
- return embedding
150
-
151
-
152
- class ControlNetModel_Union(ModelMixin, ConfigMixin, FromOriginalModelMixin):
153
- """
154
- A ControlNet model.
155
-
156
- Args:
157
- in_channels (`int`, defaults to 4):
158
- The number of channels in the input sample.
159
- flip_sin_to_cos (`bool`, defaults to `True`):
160
- Whether to flip the sin to cos in the time embedding.
161
- freq_shift (`int`, defaults to 0):
162
- The frequency shift to apply to the time embedding.
163
- down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
164
- The tuple of downsample blocks to use.
165
- only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
166
- block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
167
- The tuple of output channels for each block.
168
- layers_per_block (`int`, defaults to 2):
169
- The number of layers per block.
170
- downsample_padding (`int`, defaults to 1):
171
- The padding to use for the downsampling convolution.
172
- mid_block_scale_factor (`float`, defaults to 1):
173
- The scale factor to use for the mid block.
174
- act_fn (`str`, defaults to "silu"):
175
- The activation function to use.
176
- norm_num_groups (`int`, *optional*, defaults to 32):
177
- The number of groups to use for the normalization. If None, normalization and activation layers is skipped
178
- in post-processing.
179
- norm_eps (`float`, defaults to 1e-5):
180
- The epsilon to use for the normalization.
181
- cross_attention_dim (`int`, defaults to 1280):
182
- The dimension of the cross attention features.
183
- transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
184
- The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
185
- [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
186
- [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
187
- encoder_hid_dim (`int`, *optional*, defaults to None):
188
- If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
189
- dimension to `cross_attention_dim`.
190
- encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
191
- If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
192
- embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
193
- attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
194
- The dimension of the attention heads.
195
- use_linear_projection (`bool`, defaults to `False`):
196
- class_embed_type (`str`, *optional*, defaults to `None`):
197
- The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
198
- `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
199
- addition_embed_type (`str`, *optional*, defaults to `None`):
200
- Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
201
- "text". "text" will use the `TextTimeEmbedding` layer.
202
- num_class_embeds (`int`, *optional*, defaults to 0):
203
- Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
204
- class conditioning with `class_embed_type` equal to `None`.
205
- upcast_attention (`bool`, defaults to `False`):
206
- resnet_time_scale_shift (`str`, defaults to `"default"`):
207
- Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
208
- projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
209
- The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
210
- `class_embed_type="projection"`.
211
- controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
212
- The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
213
- conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
214
- The tuple of output channel for each block in the `conditioning_embedding` layer.
215
- global_pool_conditions (`bool`, defaults to `False`):
216
- """
217
-
218
- _supports_gradient_checkpointing = True
219
-
220
- @register_to_config
221
- def __init__(
222
- self,
223
- in_channels: int = 4,
224
- conditioning_channels: int = 3,
225
- flip_sin_to_cos: bool = True,
226
- freq_shift: int = 0,
227
- down_block_types: Tuple[str] = (
228
- "CrossAttnDownBlock2D",
229
- "CrossAttnDownBlock2D",
230
- "CrossAttnDownBlock2D",
231
- "DownBlock2D",
232
- ),
233
- only_cross_attention: Union[bool, Tuple[bool]] = False,
234
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
235
- layers_per_block: int = 2,
236
- downsample_padding: int = 1,
237
- mid_block_scale_factor: float = 1,
238
- act_fn: str = "silu",
239
- norm_num_groups: Optional[int] = 32,
240
- norm_eps: float = 1e-5,
241
- cross_attention_dim: int = 1280,
242
- transformer_layers_per_block: Union[int, Tuple[int]] = 1,
243
- encoder_hid_dim: Optional[int] = None,
244
- encoder_hid_dim_type: Optional[str] = None,
245
- attention_head_dim: Union[int, Tuple[int]] = 8,
246
- num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
247
- use_linear_projection: bool = False,
248
- class_embed_type: Optional[str] = None,
249
- addition_embed_type: Optional[str] = None,
250
- addition_time_embed_dim: Optional[int] = None,
251
- num_class_embeds: Optional[int] = None,
252
- upcast_attention: bool = False,
253
- resnet_time_scale_shift: str = "default",
254
- projection_class_embeddings_input_dim: Optional[int] = None,
255
- controlnet_conditioning_channel_order: str = "rgb",
256
- conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
257
- global_pool_conditions: bool = False,
258
- addition_embed_type_num_heads=64,
259
- num_control_type = 6,
260
- ):
261
- super().__init__()
262
-
263
- # If `num_attention_heads` is not defined (which is the case for most models)
264
- # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
265
- # The reason for this behavior is to correct for incorrectly named variables that were introduced
266
- # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
267
- # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
268
- # which is why we correct for the naming here.
269
- num_attention_heads = num_attention_heads or attention_head_dim
270
-
271
- # Check inputs
272
- if len(block_out_channels) != len(down_block_types):
273
- raise ValueError(
274
- f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
275
- )
276
-
277
- if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
278
- raise ValueError(
279
- f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
280
- )
281
-
282
- if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
283
- raise ValueError(
284
- f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
285
- )
286
-
287
- if isinstance(transformer_layers_per_block, int):
288
- transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
289
-
290
- # input
291
- conv_in_kernel = 3
292
- conv_in_padding = (conv_in_kernel - 1) // 2
293
- self.conv_in = nn.Conv2d(
294
- in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
295
- )
296
-
297
- # time
298
- time_embed_dim = block_out_channels[0] * 4
299
- self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
300
- timestep_input_dim = block_out_channels[0]
301
- self.time_embedding = TimestepEmbedding(
302
- timestep_input_dim,
303
- time_embed_dim,
304
- act_fn=act_fn,
305
- )
306
-
307
- if encoder_hid_dim_type is None and encoder_hid_dim is not None:
308
- encoder_hid_dim_type = "text_proj"
309
- self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
310
- logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
311
-
312
- if encoder_hid_dim is None and encoder_hid_dim_type is not None:
313
- raise ValueError(
314
- f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
315
- )
316
-
317
- if encoder_hid_dim_type == "text_proj":
318
- self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
319
- elif encoder_hid_dim_type == "text_image_proj":
320
- # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
321
- # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
322
- # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
323
- self.encoder_hid_proj = TextImageProjection(
324
- text_embed_dim=encoder_hid_dim,
325
- image_embed_dim=cross_attention_dim,
326
- cross_attention_dim=cross_attention_dim,
327
- )
328
-
329
- elif encoder_hid_dim_type is not None:
330
- raise ValueError(
331
- f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
332
- )
333
- else:
334
- self.encoder_hid_proj = None
335
-
336
- # class embedding
337
- if class_embed_type is None and num_class_embeds is not None:
338
- self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
339
- elif class_embed_type == "timestep":
340
- self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
341
- elif class_embed_type == "identity":
342
- self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
343
- elif class_embed_type == "projection":
344
- if projection_class_embeddings_input_dim is None:
345
- raise ValueError(
346
- "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
347
- )
348
- # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
349
- # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
350
- # 2. it projects from an arbitrary input dimension.
351
- #
352
- # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
353
- # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
354
- # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
355
- self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
356
- else:
357
- self.class_embedding = None
358
-
359
- if addition_embed_type == "text":
360
- if encoder_hid_dim is not None:
361
- text_time_embedding_from_dim = encoder_hid_dim
362
- else:
363
- text_time_embedding_from_dim = cross_attention_dim
364
-
365
- self.add_embedding = TextTimeEmbedding(
366
- text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
367
- )
368
- elif addition_embed_type == "text_image":
369
- # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
370
- # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
371
- # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
372
- self.add_embedding = TextImageTimeEmbedding(
373
- text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
374
- )
375
- elif addition_embed_type == "text_time":
376
- self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
377
- self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
378
-
379
- elif addition_embed_type is not None:
380
- raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
381
-
382
- # control net conditioning embedding
383
- self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
384
- conditioning_embedding_channels=block_out_channels[0],
385
- block_out_channels=conditioning_embedding_out_channels,
386
- conditioning_channels=conditioning_channels,
387
- )
388
-
389
- # Copyright by Qi Xin(2024/07/06)
390
- # Condition Transformer(fuse single/multi conditions with input image)
391
- # The Condition Transformer augments the feature representation of the conditions
392
- # The overall design is somewhat like a ResNet: the output of the Condition Transformer is used to predict a condition bias that is added to the original condition feature.
393
- # num_control_type = 6
394
- num_trans_channel = 320
395
- num_trans_head = 8
396
- num_trans_layer = 1
397
- num_proj_channel = 320
398
- task_scale_factor = num_trans_channel ** 0.5
399
-
400
- self.task_embedding = nn.Parameter(task_scale_factor * torch.randn(num_control_type, num_trans_channel))
401
- self.transformer_layes = nn.Sequential(*[ResidualAttentionBlock(num_trans_channel, num_trans_head) for _ in range(num_trans_layer)])
402
- self.spatial_ch_projs = zero_module(nn.Linear(num_trans_channel, num_proj_channel))
403
- #-----------------------------------------------------------------------------------------------------
404
-
405
- # Copyright by Qi Xin(2024/07/06)
406
- # Control Encoder to distinguish different control conditions
407
- # A simple but effective module, consisting of an embedding layer and a linear layer, that injects the control type info into the time embedding.
408
- self.control_type_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
409
- self.control_add_embedding = TimestepEmbedding(addition_time_embed_dim * num_control_type, time_embed_dim)
410
- #-----------------------------------------------------------------------------------------------------
411
-
412
- self.down_blocks = nn.ModuleList([])
413
- self.controlnet_down_blocks = nn.ModuleList([])
414
-
415
- if isinstance(only_cross_attention, bool):
416
- only_cross_attention = [only_cross_attention] * len(down_block_types)
417
-
418
- if isinstance(attention_head_dim, int):
419
- attention_head_dim = (attention_head_dim,) * len(down_block_types)
420
-
421
- if isinstance(num_attention_heads, int):
422
- num_attention_heads = (num_attention_heads,) * len(down_block_types)
423
-
424
- # down
425
- output_channel = block_out_channels[0]
426
-
427
- controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
428
- controlnet_block = zero_module(controlnet_block)
429
- self.controlnet_down_blocks.append(controlnet_block)
430
-
431
- for i, down_block_type in enumerate(down_block_types):
432
- input_channel = output_channel
433
- output_channel = block_out_channels[i]
434
- is_final_block = i == len(block_out_channels) - 1
435
-
436
- down_block = get_down_block(
437
- down_block_type,
438
- num_layers=layers_per_block,
439
- transformer_layers_per_block=transformer_layers_per_block[i],
440
- in_channels=input_channel,
441
- out_channels=output_channel,
442
- temb_channels=time_embed_dim,
443
- add_downsample=not is_final_block,
444
- resnet_eps=norm_eps,
445
- resnet_act_fn=act_fn,
446
- resnet_groups=norm_num_groups,
447
- cross_attention_dim=cross_attention_dim,
448
- num_attention_heads=num_attention_heads[i],
449
- attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
450
- downsample_padding=downsample_padding,
451
- use_linear_projection=use_linear_projection,
452
- only_cross_attention=only_cross_attention[i],
453
- upcast_attention=upcast_attention,
454
- resnet_time_scale_shift=resnet_time_scale_shift,
455
- )
456
- self.down_blocks.append(down_block)
457
-
458
- for _ in range(layers_per_block):
459
- controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
460
- controlnet_block = zero_module(controlnet_block)
461
- self.controlnet_down_blocks.append(controlnet_block)
462
-
463
- if not is_final_block:
464
- controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
465
- controlnet_block = zero_module(controlnet_block)
466
- self.controlnet_down_blocks.append(controlnet_block)
467
-
468
- # mid
469
- mid_block_channel = block_out_channels[-1]
470
-
471
- controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
472
- controlnet_block = zero_module(controlnet_block)
473
- self.controlnet_mid_block = controlnet_block
474
-
475
- self.mid_block = UNetMidBlock2DCrossAttn(
476
- transformer_layers_per_block=transformer_layers_per_block[-1],
477
- in_channels=mid_block_channel,
478
- temb_channels=time_embed_dim,
479
- resnet_eps=norm_eps,
480
- resnet_act_fn=act_fn,
481
- output_scale_factor=mid_block_scale_factor,
482
- resnet_time_scale_shift=resnet_time_scale_shift,
483
- cross_attention_dim=cross_attention_dim,
484
- num_attention_heads=num_attention_heads[-1],
485
- resnet_groups=norm_num_groups,
486
- use_linear_projection=use_linear_projection,
487
- upcast_attention=upcast_attention,
488
- )
489
-
490
- @classmethod
491
- def from_unet(
492
- cls,
493
- unet: UNet2DConditionModel,
494
- controlnet_conditioning_channel_order: str = "rgb",
495
- conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
496
- load_weights_from_unet: bool = True,
497
- ):
498
- r"""
499
- Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
500
-
501
- Parameters:
502
- unet (`UNet2DConditionModel`):
503
- The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
504
- where applicable.
505
- """
506
- transformer_layers_per_block = (
507
- unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
508
- )
509
- encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
510
- encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
511
- addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
512
- addition_time_embed_dim = (
513
- unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
514
- )
515
-
516
- controlnet = cls(
517
- encoder_hid_dim=encoder_hid_dim,
518
- encoder_hid_dim_type=encoder_hid_dim_type,
519
- addition_embed_type=addition_embed_type,
520
- addition_time_embed_dim=addition_time_embed_dim,
521
- transformer_layers_per_block=transformer_layers_per_block,
522
- # transformer_layers_per_block=[1, 2, 5],
523
- in_channels=unet.config.in_channels,
524
- flip_sin_to_cos=unet.config.flip_sin_to_cos,
525
- freq_shift=unet.config.freq_shift,
526
- down_block_types=unet.config.down_block_types,
527
- only_cross_attention=unet.config.only_cross_attention,
528
- block_out_channels=unet.config.block_out_channels,
529
- layers_per_block=unet.config.layers_per_block,
530
- downsample_padding=unet.config.downsample_padding,
531
- mid_block_scale_factor=unet.config.mid_block_scale_factor,
532
- act_fn=unet.config.act_fn,
533
- norm_num_groups=unet.config.norm_num_groups,
534
- norm_eps=unet.config.norm_eps,
535
- cross_attention_dim=unet.config.cross_attention_dim,
536
- attention_head_dim=unet.config.attention_head_dim,
537
- num_attention_heads=unet.config.num_attention_heads,
538
- use_linear_projection=unet.config.use_linear_projection,
539
- class_embed_type=unet.config.class_embed_type,
540
- num_class_embeds=unet.config.num_class_embeds,
541
- upcast_attention=unet.config.upcast_attention,
542
- resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
543
- projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
544
- controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
545
- conditioning_embedding_out_channels=conditioning_embedding_out_channels,
546
- )
547
-
548
- if load_weights_from_unet:
549
- controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
550
- controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
551
- controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
552
-
553
- if controlnet.class_embedding:
554
- controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
555
-
556
- controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
557
- controlnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
558
-
559
- return controlnet
560
-
561
- @property
562
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
563
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
564
- r"""
565
- Returns:
566
- `dict` of attention processors: A dictionary containing all attention processors used in the model,
567
- indexed by its weight name.
568
- """
569
- # set recursively
570
- processors = {}
571
-
572
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
573
- if hasattr(module, "get_processor"):
574
- processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
575
-
576
- for sub_name, child in module.named_children():
577
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
578
-
579
- return processors
580
-
581
- for name, module in self.named_children():
582
- fn_recursive_add_processors(name, module, processors)
583
-
584
- return processors
585
-
586
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
587
- def set_attn_processor(
588
- self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
589
- ):
590
- r"""
591
- Sets the attention processor to use to compute attention.
592
-
593
- Parameters:
594
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
595
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
596
- for **all** `Attention` layers.
597
-
598
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
599
- processor. This is strongly recommended when setting trainable attention processors.
600
-
601
- """
602
- count = len(self.attn_processors.keys())
603
-
604
- if isinstance(processor, dict) and len(processor) != count:
605
- raise ValueError(
606
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
607
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
608
- )
609
-
610
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
611
- if hasattr(module, "set_processor"):
612
- if not isinstance(processor, dict):
613
- module.set_processor(processor, _remove_lora=_remove_lora)
614
- else:
615
- module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
616
-
617
- for sub_name, child in module.named_children():
618
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
619
-
620
- for name, module in self.named_children():
621
- fn_recursive_attn_processor(name, module, processor)
622
-
623
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
624
- def set_default_attn_processor(self):
625
- """
626
- Disables custom attention processors and sets the default attention implementation.
627
- """
628
- if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
629
- processor = AttnAddedKVProcessor()
630
- elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
631
- processor = AttnProcessor()
632
- else:
633
- raise ValueError(
634
- f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
635
- )
636
-
637
- self.set_attn_processor(processor, _remove_lora=True)
638
-
639
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
640
- def set_attention_slice(self, slice_size):
641
- r"""
642
- Enable sliced attention computation.
643
-
644
- When this option is enabled, the attention module splits the input tensor in slices to compute attention in
645
- several steps. This is useful for saving some memory in exchange for a small decrease in speed.
646
-
647
- Args:
648
- slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
649
- When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
650
- `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
651
- provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
652
- must be a multiple of `slice_size`.
653
- """
654
- sliceable_head_dims = []
655
-
656
- def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
657
- if hasattr(module, "set_attention_slice"):
658
- sliceable_head_dims.append(module.sliceable_head_dim)
659
-
660
- for child in module.children():
661
- fn_recursive_retrieve_sliceable_dims(child)
662
-
663
- # retrieve number of attention layers
664
- for module in self.children():
665
- fn_recursive_retrieve_sliceable_dims(module)
666
-
667
- num_sliceable_layers = len(sliceable_head_dims)
668
-
669
- if slice_size == "auto":
670
- # half the attention head size is usually a good trade-off between
671
- # speed and memory
672
- slice_size = [dim // 2 for dim in sliceable_head_dims]
673
- elif slice_size == "max":
674
- # make smallest slice possible
675
- slice_size = num_sliceable_layers * [1]
676
-
677
- slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
678
-
679
- if len(slice_size) != len(sliceable_head_dims):
680
- raise ValueError(
681
- f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
682
- f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
683
- )
684
-
685
- for i in range(len(slice_size)):
686
- size = slice_size[i]
687
- dim = sliceable_head_dims[i]
688
- if size is not None and size > dim:
689
- raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
690
-
691
- # Recursively walk through all the children.
692
- # Any children which exposes the set_attention_slice method
693
- # gets the message
694
- def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
695
- if hasattr(module, "set_attention_slice"):
696
- module.set_attention_slice(slice_size.pop())
697
-
698
- for child in module.children():
699
- fn_recursive_set_attention_slice(child, slice_size)
700
-
701
- reversed_slice_size = list(reversed(slice_size))
702
- for module in self.children():
703
- fn_recursive_set_attention_slice(module, reversed_slice_size)
704
-
705
-
706
- def _set_gradient_checkpointing(self, module, value=False):
707
- if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
708
- module.gradient_checkpointing = value
709
-
710
-
711
- def forward(
712
- self,
713
- sample: torch.FloatTensor,
714
- timestep: Union[torch.Tensor, float, int],
715
- encoder_hidden_states: torch.Tensor,
716
- controlnet_cond_list: torch.FloatTensor,
717
- conditioning_scale: float = 1.0,
718
- class_labels: Optional[torch.Tensor] = None,
719
- timestep_cond: Optional[torch.Tensor] = None,
720
- attention_mask: Optional[torch.Tensor] = None,
721
- added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
722
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
723
- guess_mode: bool = False,
724
- return_dict: bool = True,
725
- ) -> Union[ControlNetOutput, Tuple]:
726
- """
727
- The [`ControlNetModel`] forward method.
728
-
729
- Args:
730
- sample (`torch.FloatTensor`):
731
- The noisy input tensor.
732
- timestep (`Union[torch.Tensor, float, int]`):
733
- The number of timesteps to denoise an input.
734
- encoder_hidden_states (`torch.Tensor`):
735
- The encoder hidden states.
736
- controlnet_cond_list (`List[torch.FloatTensor]`):
737
- The list of conditional input images, one per control type, each of shape `(batch_size, 3, height, width)`.
738
- conditioning_scale (`float`, defaults to `1.0`):
739
- The scale factor for ControlNet outputs.
740
- class_labels (`torch.Tensor`, *optional*, defaults to `None`):
741
- Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
742
- timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
743
- Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
744
- timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
745
- embeddings.
746
- attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
747
- An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
748
- is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
749
- negative values to the attention scores corresponding to "discard" tokens.
750
- added_cond_kwargs (`dict`):
751
- Additional conditions for the Stable Diffusion XL UNet.
752
- cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
753
- A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
754
- guess_mode (`bool`, defaults to `False`):
755
- In this mode, the ControlNet encoder tries its best to recognize the content of the input image even if
756
- you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
757
- return_dict (`bool`, defaults to `True`):
758
- Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
759
-
760
- Returns:
761
- [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
762
- If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
763
- returned where the first element is the sample tensor.
764
- """
765
- # check channel order
766
- channel_order = self.config.controlnet_conditioning_channel_order
767
-
768
- if channel_order == "rgb":
769
- # in rgb order by default
770
- ...
771
- # elif channel_order == "bgr":
772
- # controlnet_cond = torch.flip(controlnet_cond, dims=[1])
773
- else:
774
- raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
775
-
776
- # prepare attention_mask
777
- if attention_mask is not None:
778
- attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
779
- attention_mask = attention_mask.unsqueeze(1)
780
-
781
- # 1. time
782
- timesteps = timestep
783
- if not torch.is_tensor(timesteps):
784
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
785
- # This would be a good case for the `match` statement (Python 3.10+)
786
- is_mps = sample.device.type == "mps"
787
- if isinstance(timestep, float):
788
- dtype = torch.float32 if is_mps else torch.float64
789
- else:
790
- dtype = torch.int32 if is_mps else torch.int64
791
- timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
792
- elif len(timesteps.shape) == 0:
793
- timesteps = timesteps[None].to(sample.device)
794
-
795
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
796
- timesteps = timesteps.expand(sample.shape[0])
797
-
798
- t_emb = self.time_proj(timesteps)
799
-
800
- # timesteps does not contain any weights and will always return f32 tensors
801
- # but time_embedding might actually be running in fp16. so we need to cast here.
802
- # there might be better ways to encapsulate this.
803
- t_emb = t_emb.to(dtype=sample.dtype)
804
-
805
- emb = self.time_embedding(t_emb, timestep_cond)
806
- aug_emb = None
807
-
808
- if self.class_embedding is not None:
809
- if class_labels is None:
810
- raise ValueError("class_labels should be provided when num_class_embeds > 0")
811
-
812
- if self.config.class_embed_type == "timestep":
813
- class_labels = self.time_proj(class_labels)
814
-
815
- class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
816
- emb = emb + class_emb
817
-
818
- if self.config.addition_embed_type is not None:
819
- if self.config.addition_embed_type == "text":
820
- aug_emb = self.add_embedding(encoder_hidden_states)
821
-
822
- elif self.config.addition_embed_type == "text_time":
823
- if "text_embeds" not in added_cond_kwargs:
824
- raise ValueError(
825
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
826
- )
827
- text_embeds = added_cond_kwargs.get("text_embeds")
828
- if "time_ids" not in added_cond_kwargs:
829
- raise ValueError(
830
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
831
- )
832
- time_ids = added_cond_kwargs.get("time_ids")
833
- time_embeds = self.add_time_proj(time_ids.flatten())
834
- time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
835
-
836
- add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
837
- add_embeds = add_embeds.to(emb.dtype)
838
- aug_emb = self.add_embedding(add_embeds)
839
-
840
- # Copyright by Qi Xin(2024/07/06)
841
- # inject control type info to time embedding to distinguish different control conditions
842
- control_type = added_cond_kwargs.get('control_type')
843
- control_embeds = self.control_type_proj(control_type.flatten())
844
- control_embeds = control_embeds.reshape((t_emb.shape[0], -1))
845
- control_embeds = control_embeds.to(emb.dtype)
846
- control_emb = self.control_add_embedding(control_embeds)
847
- emb = emb + control_emb
848
- #---------------------------------------------------------------------------------
849
-
850
- emb = emb + aug_emb if aug_emb is not None else emb
851
-
852
- # 2. pre-process
853
- sample = self.conv_in(sample)
854
- indices = torch.nonzero(control_type[0])
855
-
856
- # Copyright by Qi Xin(2024/07/06)
857
- # add single/multi conditions to the input image.
858
- # Condition Transformer provides an easy and effective way to fuse different features naturally
859
- inputs = []
860
- condition_list = []
861
-
862
- for idx in range(indices.shape[0] + 1):
863
- if idx == indices.shape[0]:
864
- controlnet_cond = sample
865
- feat_seq = torch.mean(controlnet_cond, dim=(2, 3)) # N * C
866
- else:
867
- controlnet_cond = self.controlnet_cond_embedding(controlnet_cond_list[indices[idx][0]])
868
- feat_seq = torch.mean(controlnet_cond, dim=(2, 3)) # N * C
869
- feat_seq = feat_seq + self.task_embedding[indices[idx][0]]
870
-
871
- inputs.append(feat_seq.unsqueeze(1))
872
- condition_list.append(controlnet_cond)
873
-
874
- x = torch.cat(inputs, dim=1) # NxLxC
875
- x = self.transformer_layes(x)
876
-
877
- controlnet_cond_fuser = sample * 0.0
878
- for idx in range(indices.shape[0]):
879
- alpha = self.spatial_ch_projs(x[:, idx])
880
- alpha = alpha.unsqueeze(-1).unsqueeze(-1)
881
- controlnet_cond_fuser += condition_list[idx] + alpha
882
-
883
- sample = sample + controlnet_cond_fuser
884
- #-------------------------------------------------------------------------------------------
885
-
886
- # 3. down
887
- down_block_res_samples = (sample,)
888
- for downsample_block in self.down_blocks:
889
- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
890
- sample, res_samples = downsample_block(
891
- hidden_states=sample,
892
- temb=emb,
893
- encoder_hidden_states=encoder_hidden_states,
894
- attention_mask=attention_mask,
895
- cross_attention_kwargs=cross_attention_kwargs,
896
- )
897
- else:
898
- sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
899
-
900
- down_block_res_samples += res_samples
901
-
902
- # 4. mid
903
- if self.mid_block is not None:
904
- sample = self.mid_block(
905
- sample,
906
- emb,
907
- encoder_hidden_states=encoder_hidden_states,
908
- attention_mask=attention_mask,
909
- cross_attention_kwargs=cross_attention_kwargs,
910
- )
911
-
912
- # 5. Control net blocks
913
-
914
- controlnet_down_block_res_samples = ()
915
-
916
- for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
917
- down_block_res_sample = controlnet_block(down_block_res_sample)
918
- controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
919
-
920
- down_block_res_samples = controlnet_down_block_res_samples
921
-
922
- mid_block_res_sample = self.controlnet_mid_block(sample)
923
-
924
- # 6. scaling
925
- if guess_mode and not self.config.global_pool_conditions:
926
- scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
927
- scales = scales * conditioning_scale
928
- down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
929
- mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
930
- else:
931
- down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
932
- mid_block_res_sample = mid_block_res_sample * conditioning_scale
933
-
934
- if self.config.global_pool_conditions:
935
- down_block_res_samples = [
936
- torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
937
- ]
938
- mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
939
-
940
- if not return_dict:
941
- return (down_block_res_samples, mid_block_res_sample)
942
-
943
- return ControlNetOutput(
944
- down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
945
- )
946
-
947
-
948
-
949
- def zero_module(module):
950
- for p in module.parameters():
951
- nn.init.zeros_(p)
952
- return module
953
-
954
-
955
-
956
-
957
-
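For reference, a minimal construction and shape-check sketch for the `ControlNetModel_Union` class in this file (illustration only, not part of the upload). `from_unet` and the `forward` signature are taken from the code above; the SDXL checkpoint id, the active condition slot, and the tensor shapes are assumptions.

import torch
from diffusers import UNet2DConditionModel
from t2i.controlnet_union.models.controlnet_union import ControlNetModel_Union  # module path as uploaded in this commit

# Assumed SDXL base UNet (heavy to download and instantiate); any standard SDXL UNet config should behave the same way.
unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet")
controlnet = ControlNetModel_Union.from_unet(unet)

batch, height, width = 1, 1024, 1024
sample = torch.randn(batch, 4, height // 8, width // 8)      # noisy SDXL latents
encoder_hidden_states = torch.randn(batch, 77, 2048)         # concatenated text-encoder hidden states

# Six condition slots; activate one (the slot-to-task mapping depends on the trained checkpoint).
cond_images = [torch.zeros(batch, 3, height, width) for _ in range(6)]
cond_images[0] = torch.randn(batch, 3, height, width)
control_type = torch.zeros(batch, 6)
control_type[:, 0] = 1

added_cond_kwargs = {
    "text_embeds": torch.randn(batch, 1280),                 # pooled text embedding
    "time_ids": torch.zeros(batch, 6),                       # SDXL size/crop conditioning
    "control_type": control_type,
}

down_samples, mid_sample = controlnet(
    sample, timestep=10, encoder_hidden_states=encoder_hidden_states,
    controlnet_cond_list=cond_images, conditioning_scale=1.0,
    added_cond_kwargs=added_cond_kwargs, return_dict=False,
)
print(len(down_samples), mid_sample.shape)  # one residual per down-block output plus the mid-block tensor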
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ from torch import nn
19
+ from torch.nn import functional as F
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
23
+ from diffusers.utils import BaseOutput, logging
24
+ from diffusers.models.attention_processor import (
25
+ ADDED_KV_ATTENTION_PROCESSORS,
26
+ CROSS_ATTENTION_PROCESSORS,
27
+ AttentionProcessor,
28
+ AttnAddedKVProcessor,
29
+ AttnProcessor,
30
+ )
31
+ from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
32
+ from diffusers.models.modeling_utils import ModelMixin
33
+ from diffusers.models.unets.unet_2d_blocks import (
34
+ CrossAttnDownBlock2D,
35
+ DownBlock2D,
36
+ UNetMidBlock2DCrossAttn,
37
+ get_down_block,
38
+ )
39
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
40
+
41
+
42
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
43
+
44
+
45
+ from collections import OrderedDict
46
+
47
+ # Transformer Block
48
+ # Used to exchange info between different conditions and input image
49
+ # With reference to https://github.com/TencentARC/T2I-Adapter/blob/SD/ldm/modules/encoders/adapter.py#L147
50
+ class QuickGELU(nn.Module):
51
+
52
+ def forward(self, x: torch.Tensor):
53
+ return x * torch.sigmoid(1.702 * x)
54
+
55
+ class LayerNorm(nn.LayerNorm):
56
+ """Subclass torch's LayerNorm to handle fp16."""
57
+
58
+ def forward(self, x: torch.Tensor):
59
+ orig_type = x.dtype
60
+ ret = super().forward(x)
61
+ return ret.type(orig_type)
62
+
63
+ class ResidualAttentionBlock(nn.Module):
64
+
65
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
66
+ super().__init__()
67
+
68
+ self.attn = nn.MultiheadAttention(d_model, n_head)
69
+ self.ln_1 = LayerNorm(d_model)
70
+ self.mlp = nn.Sequential(
71
+ OrderedDict([("c_fc", nn.Linear(d_model, d_model * 4)), ("gelu", QuickGELU()),
72
+ ("c_proj", nn.Linear(d_model * 4, d_model))]))
73
+ self.ln_2 = LayerNorm(d_model)
74
+ self.attn_mask = attn_mask
75
+
76
+ def attention(self, x: torch.Tensor):
77
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
78
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
79
+
80
+ def forward(self, x: torch.Tensor):
81
+ x = x + self.attention(self.ln_1(x))
82
+ x = x + self.mlp(self.ln_2(x))
83
+ return x
84
+ #-----------------------------------------------------------------------------------------------------
85
+
86
+ @dataclass
87
+ class ControlNetOutput(BaseOutput):
88
+ """
89
+ The output of [`ControlNetModel`].
90
+
91
+ Args:
92
+ down_block_res_samples (`tuple[torch.Tensor]`):
93
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
94
+ be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
95
+ used to condition the original UNet's downsampling activations.
96
+ mid_block_res_sample (`torch.Tensor`):
97
+ The activation of the middle block (the lowest sample resolution). The tensor should be of shape
98
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
99
+ Output can be used to condition the original UNet's middle block activation.
100
+ """
101
+
102
+ down_block_res_samples: Tuple[torch.Tensor]
103
+ mid_block_res_sample: torch.Tensor
104
+
105
+
106
+ class ControlNetConditioningEmbedding(nn.Module):
107
+ """
108
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
109
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
110
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
111
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
112
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
113
+ model) to encode image-space conditions ... into feature maps ..."
114
+ """
115
+
116
+ # original setting is (16, 32, 96, 256)
117
+ def __init__(
118
+ self,
119
+ conditioning_embedding_channels: int,
120
+ conditioning_channels: int = 3,
121
+ block_out_channels: Tuple[int] = (48, 96, 192, 384),
122
+ ):
123
+ super().__init__()
124
+
125
+ self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
126
+
127
+ self.blocks = nn.ModuleList([])
128
+
129
+ for i in range(len(block_out_channels) - 1):
130
+ channel_in = block_out_channels[i]
131
+ channel_out = block_out_channels[i + 1]
132
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
133
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
134
+
135
+ self.conv_out = zero_module(
136
+ nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
137
+ )
138
+
139
+ def forward(self, conditioning):
140
+ embedding = self.conv_in(conditioning)
141
+ embedding = F.silu(embedding)
142
+
143
+ for block in self.blocks:
144
+ embedding = block(embedding)
145
+ embedding = F.silu(embedding)
146
+
147
+ embedding = self.conv_out(embedding)
148
+
149
+ return embedding
150
+
151
+
152
+ class ControlNetModel_Union(ModelMixin, ConfigMixin, FromOriginalModelMixin):
153
+ """
154
+ A ControlNet model.
155
+
156
+ Args:
157
+ in_channels (`int`, defaults to 4):
158
+ The number of channels in the input sample.
159
+ flip_sin_to_cos (`bool`, defaults to `True`):
160
+ Whether to flip the sin to cos in the time embedding.
161
+ freq_shift (`int`, defaults to 0):
162
+ The frequency shift to apply to the time embedding.
163
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
164
+ The tuple of downsample blocks to use.
165
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
166
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
167
+ The tuple of output channels for each block.
168
+ layers_per_block (`int`, defaults to 2):
169
+ The number of layers per block.
170
+ downsample_padding (`int`, defaults to 1):
171
+ The padding to use for the downsampling convolution.
172
+ mid_block_scale_factor (`float`, defaults to 1):
173
+ The scale factor to use for the mid block.
174
+ act_fn (`str`, defaults to "silu"):
175
+ The activation function to use.
176
+ norm_num_groups (`int`, *optional*, defaults to 32):
177
+ The number of groups to use for the normalization. If None, normalization and activation layers is skipped
178
+ in post-processing.
179
+ norm_eps (`float`, defaults to 1e-5):
180
+ The epsilon to use for the normalization.
181
+ cross_attention_dim (`int`, defaults to 1280):
182
+ The dimension of the cross attention features.
183
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
184
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
185
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
186
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
187
+ encoder_hid_dim (`int`, *optional*, defaults to None):
188
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
189
+ dimension to `cross_attention_dim`.
190
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
191
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
192
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
193
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
194
+ The dimension of the attention heads.
195
+ use_linear_projection (`bool`, defaults to `False`):
196
+ class_embed_type (`str`, *optional*, defaults to `None`):
197
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
198
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
199
+ addition_embed_type (`str`, *optional*, defaults to `None`):
200
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
201
+ "text". "text" will use the `TextTimeEmbedding` layer.
202
+ num_class_embeds (`int`, *optional*, defaults to 0):
203
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
204
+ class conditioning with `class_embed_type` equal to `None`.
205
+ upcast_attention (`bool`, defaults to `False`):
206
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
207
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
208
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
209
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
210
+ `class_embed_type="projection"`.
211
+ controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
212
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
213
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
214
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
215
+ global_pool_conditions (`bool`, defaults to `False`):
216
+ """
217
+
218
+ _supports_gradient_checkpointing = True
219
+
220
+ @register_to_config
221
+ def __init__(
222
+ self,
223
+ in_channels: int = 4,
224
+ conditioning_channels: int = 3,
225
+ flip_sin_to_cos: bool = True,
226
+ freq_shift: int = 0,
227
+ down_block_types: Tuple[str] = (
228
+ "CrossAttnDownBlock2D",
229
+ "CrossAttnDownBlock2D",
230
+ "CrossAttnDownBlock2D",
231
+ "DownBlock2D",
232
+ ),
233
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
234
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
235
+ layers_per_block: int = 2,
236
+ downsample_padding: int = 1,
237
+ mid_block_scale_factor: float = 1,
238
+ act_fn: str = "silu",
239
+ norm_num_groups: Optional[int] = 32,
240
+ norm_eps: float = 1e-5,
241
+ cross_attention_dim: int = 1280,
242
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
243
+ encoder_hid_dim: Optional[int] = None,
244
+ encoder_hid_dim_type: Optional[str] = None,
245
+ attention_head_dim: Union[int, Tuple[int]] = 8,
246
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
247
+ use_linear_projection: bool = False,
248
+ class_embed_type: Optional[str] = None,
249
+ addition_embed_type: Optional[str] = None,
250
+ addition_time_embed_dim: Optional[int] = None,
251
+ num_class_embeds: Optional[int] = None,
252
+ upcast_attention: bool = False,
253
+ resnet_time_scale_shift: str = "default",
254
+ projection_class_embeddings_input_dim: Optional[int] = None,
255
+ controlnet_conditioning_channel_order: str = "rgb",
256
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
257
+ global_pool_conditions: bool = False,
258
+ addition_embed_type_num_heads=64,
259
+ num_control_type = 6,
260
+ ):
261
+ super().__init__()
262
+
263
+ # If `num_attention_heads` is not defined (which is the case for most models)
264
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
265
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
266
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
267
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
268
+ # which is why we correct for the naming here.
269
+ num_attention_heads = num_attention_heads or attention_head_dim
270
+
271
+ # Check inputs
272
+ if len(block_out_channels) != len(down_block_types):
273
+ raise ValueError(
274
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
275
+ )
276
+
277
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
278
+ raise ValueError(
279
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
280
+ )
281
+
282
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
283
+ raise ValueError(
284
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
285
+ )
286
+
287
+ if isinstance(transformer_layers_per_block, int):
288
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
289
+
290
+ # input
291
+ conv_in_kernel = 3
292
+ conv_in_padding = (conv_in_kernel - 1) // 2
293
+ self.conv_in = nn.Conv2d(
294
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
295
+ )
296
+
297
+ # time
298
+ time_embed_dim = block_out_channels[0] * 4
299
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
300
+ timestep_input_dim = block_out_channels[0]
301
+ self.time_embedding = TimestepEmbedding(
302
+ timestep_input_dim,
303
+ time_embed_dim,
304
+ act_fn=act_fn,
305
+ )
306
+
307
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
308
+ encoder_hid_dim_type = "text_proj"
309
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
310
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
311
+
312
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
313
+ raise ValueError(
314
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
315
+ )
316
+
317
+ if encoder_hid_dim_type == "text_proj":
318
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
319
+ elif encoder_hid_dim_type == "text_image_proj":
320
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
321
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
322
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
323
+ self.encoder_hid_proj = TextImageProjection(
324
+ text_embed_dim=encoder_hid_dim,
325
+ image_embed_dim=cross_attention_dim,
326
+ cross_attention_dim=cross_attention_dim,
327
+ )
328
+
329
+ elif encoder_hid_dim_type is not None:
330
+ raise ValueError(
331
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
332
+ )
333
+ else:
334
+ self.encoder_hid_proj = None
335
+
336
+ # class embedding
337
+ if class_embed_type is None and num_class_embeds is not None:
338
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
339
+ elif class_embed_type == "timestep":
340
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
341
+ elif class_embed_type == "identity":
342
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
343
+ elif class_embed_type == "projection":
344
+ if projection_class_embeddings_input_dim is None:
345
+ raise ValueError(
346
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
347
+ )
348
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
349
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
350
+ # 2. it projects from an arbitrary input dimension.
351
+ #
352
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
353
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
354
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
355
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
356
+ else:
357
+ self.class_embedding = None
358
+
359
+ if addition_embed_type == "text":
360
+ if encoder_hid_dim is not None:
361
+ text_time_embedding_from_dim = encoder_hid_dim
362
+ else:
363
+ text_time_embedding_from_dim = cross_attention_dim
364
+
365
+ self.add_embedding = TextTimeEmbedding(
366
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
367
+ )
368
+ elif addition_embed_type == "text_image":
369
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
370
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
371
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
372
+ self.add_embedding = TextImageTimeEmbedding(
373
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
374
+ )
375
+ elif addition_embed_type == "text_time":
376
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
377
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
378
+
379
+ elif addition_embed_type is not None:
380
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
381
+
382
+ # control net conditioning embedding
383
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
384
+ conditioning_embedding_channels=block_out_channels[0],
385
+ block_out_channels=conditioning_embedding_out_channels,
386
+ conditioning_channels=conditioning_channels,
387
+ )
388
+
389
+ # Copyright by Qi Xin(2024/07/06)
390
+ # Condition Transformer (fuses single/multi conditions with the input image)
391
+ # The Condition Transformer augments the feature representation of the conditions.
392
+ # The overall design is somewhat like a ResNet: the output of the Condition Transformer is used to predict a condition bias that is added to the original condition feature.
393
+ # num_control_type = 6
394
+ num_trans_channel = 320
395
+ num_trans_head = 8
396
+ num_trans_layer = 1
397
+ num_proj_channel = 320
398
+ task_scale_factor = num_trans_channel ** 0.5
399
+
400
+ self.task_embedding = nn.Parameter(task_scale_factor * torch.randn(num_control_type, num_trans_channel))
401
+ self.transformer_layes = nn.Sequential(*[ResidualAttentionBlock(num_trans_channel, num_trans_head) for _ in range(num_trans_layer)])
402
+ self.spatial_ch_projs = zero_module(nn.Linear(num_trans_channel, num_proj_channel))
403
+ #-----------------------------------------------------------------------------------------------------
404
+
405
+ # Copyright by Qi Xin(2024/07/06)
406
+ # Control Encoder to distinguish different control conditions
407
+ # A simple but effective module, consisting of an embedding layer and a linear layer, that injects the control info into the time embedding.
408
+ self.control_type_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
409
+ self.control_add_embedding = TimestepEmbedding(addition_time_embed_dim * num_control_type, time_embed_dim)
410
+ #-----------------------------------------------------------------------------------------------------
411
+
412
+ self.down_blocks = nn.ModuleList([])
413
+ self.controlnet_down_blocks = nn.ModuleList([])
414
+
415
+ if isinstance(only_cross_attention, bool):
416
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
417
+
418
+ if isinstance(attention_head_dim, int):
419
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
420
+
421
+ if isinstance(num_attention_heads, int):
422
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
423
+
424
+ # down
425
+ output_channel = block_out_channels[0]
426
+
427
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
428
+ controlnet_block = zero_module(controlnet_block)
429
+ self.controlnet_down_blocks.append(controlnet_block)
430
+
431
+ for i, down_block_type in enumerate(down_block_types):
432
+ input_channel = output_channel
433
+ output_channel = block_out_channels[i]
434
+ is_final_block = i == len(block_out_channels) - 1
435
+
436
+ down_block = get_down_block(
437
+ down_block_type,
438
+ num_layers=layers_per_block,
439
+ transformer_layers_per_block=transformer_layers_per_block[i],
440
+ in_channels=input_channel,
441
+ out_channels=output_channel,
442
+ temb_channels=time_embed_dim,
443
+ add_downsample=not is_final_block,
444
+ resnet_eps=norm_eps,
445
+ resnet_act_fn=act_fn,
446
+ resnet_groups=norm_num_groups,
447
+ cross_attention_dim=cross_attention_dim,
448
+ num_attention_heads=num_attention_heads[i],
449
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
450
+ downsample_padding=downsample_padding,
451
+ use_linear_projection=use_linear_projection,
452
+ only_cross_attention=only_cross_attention[i],
453
+ upcast_attention=upcast_attention,
454
+ resnet_time_scale_shift=resnet_time_scale_shift,
455
+ )
456
+ self.down_blocks.append(down_block)
457
+
458
+ for _ in range(layers_per_block):
459
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
460
+ controlnet_block = zero_module(controlnet_block)
461
+ self.controlnet_down_blocks.append(controlnet_block)
462
+
463
+ if not is_final_block:
464
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
465
+ controlnet_block = zero_module(controlnet_block)
466
+ self.controlnet_down_blocks.append(controlnet_block)
467
+
468
+ # mid
469
+ mid_block_channel = block_out_channels[-1]
470
+
471
+ controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
472
+ controlnet_block = zero_module(controlnet_block)
473
+ self.controlnet_mid_block = controlnet_block
474
+
475
+ self.mid_block = UNetMidBlock2DCrossAttn(
476
+ transformer_layers_per_block=transformer_layers_per_block[-1],
477
+ in_channels=mid_block_channel,
478
+ temb_channels=time_embed_dim,
479
+ resnet_eps=norm_eps,
480
+ resnet_act_fn=act_fn,
481
+ output_scale_factor=mid_block_scale_factor,
482
+ resnet_time_scale_shift=resnet_time_scale_shift,
483
+ cross_attention_dim=cross_attention_dim,
484
+ num_attention_heads=num_attention_heads[-1],
485
+ resnet_groups=norm_num_groups,
486
+ use_linear_projection=use_linear_projection,
487
+ upcast_attention=upcast_attention,
488
+ )
489
+
490
+ @classmethod
491
+ def from_unet(
492
+ cls,
493
+ unet: UNet2DConditionModel,
494
+ controlnet_conditioning_channel_order: str = "rgb",
495
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
496
+ load_weights_from_unet: bool = True,
497
+ ):
498
+ r"""
499
+ Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
500
+
501
+ Parameters:
502
+ unet (`UNet2DConditionModel`):
503
+ The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
504
+ where applicable.
505
+ """
506
+ transformer_layers_per_block = (
507
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
508
+ )
509
+ encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
510
+ encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
511
+ addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
512
+ addition_time_embed_dim = (
513
+ unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
514
+ )
515
+
516
+ controlnet = cls(
517
+ encoder_hid_dim=encoder_hid_dim,
518
+ encoder_hid_dim_type=encoder_hid_dim_type,
519
+ addition_embed_type=addition_embed_type,
520
+ addition_time_embed_dim=addition_time_embed_dim,
521
+ transformer_layers_per_block=transformer_layers_per_block,
522
+ # transformer_layers_per_block=[1, 2, 5],
523
+ in_channels=unet.config.in_channels,
524
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
525
+ freq_shift=unet.config.freq_shift,
526
+ down_block_types=unet.config.down_block_types,
527
+ only_cross_attention=unet.config.only_cross_attention,
528
+ block_out_channels=unet.config.block_out_channels,
529
+ layers_per_block=unet.config.layers_per_block,
530
+ downsample_padding=unet.config.downsample_padding,
531
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
532
+ act_fn=unet.config.act_fn,
533
+ norm_num_groups=unet.config.norm_num_groups,
534
+ norm_eps=unet.config.norm_eps,
535
+ cross_attention_dim=unet.config.cross_attention_dim,
536
+ attention_head_dim=unet.config.attention_head_dim,
537
+ num_attention_heads=unet.config.num_attention_heads,
538
+ use_linear_projection=unet.config.use_linear_projection,
539
+ class_embed_type=unet.config.class_embed_type,
540
+ num_class_embeds=unet.config.num_class_embeds,
541
+ upcast_attention=unet.config.upcast_attention,
542
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
543
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
544
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
545
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
546
+ )
547
+
548
+ if load_weights_from_unet:
549
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
550
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
551
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
552
+
553
+ if controlnet.class_embedding:
554
+ controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
555
+
556
+ controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
557
+ controlnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
558
+
559
+ return controlnet
560
+
561
+ @property
562
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
563
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
564
+ r"""
565
+ Returns:
566
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
567
+ indexed by their weight names.
568
+ """
569
+ # set recursively
570
+ processors = {}
571
+
572
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
573
+ if hasattr(module, "get_processor"):
574
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
575
+
576
+ for sub_name, child in module.named_children():
577
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
578
+
579
+ return processors
580
+
581
+ for name, module in self.named_children():
582
+ fn_recursive_add_processors(name, module, processors)
583
+
584
+ return processors
585
+
586
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
587
+ def set_attn_processor(
588
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
589
+ ):
590
+ r"""
591
+ Sets the attention processor to use to compute attention.
592
+
593
+ Parameters:
594
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
595
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
596
+ for **all** `Attention` layers.
597
+
598
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
599
+ processor. This is strongly recommended when setting trainable attention processors.
600
+
601
+ """
602
+ count = len(self.attn_processors.keys())
603
+
604
+ if isinstance(processor, dict) and len(processor) != count:
605
+ raise ValueError(
606
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
607
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
608
+ )
609
+
610
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
611
+ if hasattr(module, "set_processor"):
612
+ if not isinstance(processor, dict):
613
+ module.set_processor(processor, _remove_lora=_remove_lora)
614
+ else:
615
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
616
+
617
+ for sub_name, child in module.named_children():
618
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
619
+
620
+ for name, module in self.named_children():
621
+ fn_recursive_attn_processor(name, module, processor)
622
+
623
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
624
+ def set_default_attn_processor(self):
625
+ """
626
+ Disables custom attention processors and sets the default attention implementation.
627
+ """
628
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
629
+ processor = AttnAddedKVProcessor()
630
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
631
+ processor = AttnProcessor()
632
+ else:
633
+ raise ValueError(
634
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
635
+ )
636
+
637
+ self.set_attn_processor(processor, _remove_lora=True)
638
+
639
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
640
+ def set_attention_slice(self, slice_size):
641
+ r"""
642
+ Enable sliced attention computation.
643
+
644
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
645
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
646
+
647
+ Args:
648
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
649
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
650
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
651
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
652
+ must be a multiple of `slice_size`.
653
+ """
654
+ sliceable_head_dims = []
655
+
656
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
657
+ if hasattr(module, "set_attention_slice"):
658
+ sliceable_head_dims.append(module.sliceable_head_dim)
659
+
660
+ for child in module.children():
661
+ fn_recursive_retrieve_sliceable_dims(child)
662
+
663
+ # retrieve number of attention layers
664
+ for module in self.children():
665
+ fn_recursive_retrieve_sliceable_dims(module)
666
+
667
+ num_sliceable_layers = len(sliceable_head_dims)
668
+
669
+ if slice_size == "auto":
670
+ # half the attention head size is usually a good trade-off between
671
+ # speed and memory
672
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
673
+ elif slice_size == "max":
674
+ # make smallest slice possible
675
+ slice_size = num_sliceable_layers * [1]
676
+
677
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
678
+
679
+ if len(slice_size) != len(sliceable_head_dims):
680
+ raise ValueError(
681
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
682
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
683
+ )
684
+
685
+ for i in range(len(slice_size)):
686
+ size = slice_size[i]
687
+ dim = sliceable_head_dims[i]
688
+ if size is not None and size > dim:
689
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
690
+
691
+ # Recursively walk through all the children.
692
+ # Any children which exposes the set_attention_slice method
693
+ # gets the message
694
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
695
+ if hasattr(module, "set_attention_slice"):
696
+ module.set_attention_slice(slice_size.pop())
697
+
698
+ for child in module.children():
699
+ fn_recursive_set_attention_slice(child, slice_size)
700
+
701
+ reversed_slice_size = list(reversed(slice_size))
702
+ for module in self.children():
703
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
704
+
705
+
706
+ def _set_gradient_checkpointing(self, module, value=False):
707
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
708
+ module.gradient_checkpointing = value
709
+
710
+
711
+ def forward(
712
+ self,
713
+ sample: torch.FloatTensor,
714
+ timestep: Union[torch.Tensor, float, int],
715
+ encoder_hidden_states: torch.Tensor,
716
+ controlnet_cond_list: torch.FloatTensor,
717
+ conditioning_scale: float = 1.0,
718
+ class_labels: Optional[torch.Tensor] = None,
719
+ timestep_cond: Optional[torch.Tensor] = None,
720
+ attention_mask: Optional[torch.Tensor] = None,
721
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
722
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
723
+ guess_mode: bool = False,
724
+ return_dict: bool = True,
725
+ ) -> Union[ControlNetOutput, Tuple]:
726
+ """
727
+ The [`ControlNetModel`] forward method.
728
+
729
+ Args:
730
+ sample (`torch.FloatTensor`):
731
+ The noisy input tensor.
732
+ timestep (`Union[torch.Tensor, float, int]`):
733
+ The number of timesteps to denoise an input.
734
+ encoder_hidden_states (`torch.Tensor`):
735
+ The encoder hidden states.
736
+ controlnet_cond_list (`torch.FloatTensor`):
737
+ The conditional input tensors, one per control type; only the entries flagged in `control_type` are used.
738
+ conditioning_scale (`float`, defaults to `1.0`):
739
+ The scale factor for ControlNet outputs.
740
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
741
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
742
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
743
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
744
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
745
+ embeddings.
746
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
747
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
748
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
749
+ negative values to the attention scores corresponding to "discard" tokens.
750
+ added_cond_kwargs (`dict`):
751
+ Additional conditions for the Stable Diffusion XL UNet.
752
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
753
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
754
+ guess_mode (`bool`, defaults to `False`):
755
+ In this mode, the ControlNet encoder tries its best to recognize the content of the input image even if
756
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
757
+ return_dict (`bool`, defaults to `True`):
758
+ Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
759
+
760
+ Returns:
761
+ [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
762
+ If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
763
+ returned where the first element is the sample tensor.
764
+ """
765
+ # check channel order
766
+ channel_order = self.config.controlnet_conditioning_channel_order
767
+
768
+ if channel_order == "rgb":
769
+ # in rgb order by default
770
+ ...
771
+ # elif channel_order == "bgr":
772
+ # controlnet_cond = torch.flip(controlnet_cond, dims=[1])
773
+ else:
774
+ raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
775
+
776
+ # prepare attention_mask
777
+ if attention_mask is not None:
778
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
779
+ attention_mask = attention_mask.unsqueeze(1)
780
+
781
+ # 1. time
782
+ timesteps = timestep
783
+ if not torch.is_tensor(timesteps):
784
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
785
+ # This would be a good case for the `match` statement (Python 3.10+)
786
+ is_mps = sample.device.type == "mps"
787
+ if isinstance(timestep, float):
788
+ dtype = torch.float32 if is_mps else torch.float64
789
+ else:
790
+ dtype = torch.int32 if is_mps else torch.int64
791
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
792
+ elif len(timesteps.shape) == 0:
793
+ timesteps = timesteps[None].to(sample.device)
794
+
795
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
796
+ timesteps = timesteps.expand(sample.shape[0])
797
+
798
+ t_emb = self.time_proj(timesteps)
799
+
800
+ # timesteps does not contain any weights and will always return f32 tensors
801
+ # but time_embedding might actually be running in fp16. so we need to cast here.
802
+ # there might be better ways to encapsulate this.
803
+ t_emb = t_emb.to(dtype=sample.dtype)
804
+
805
+ emb = self.time_embedding(t_emb, timestep_cond)
806
+ aug_emb = None
807
+
808
+ if self.class_embedding is not None:
809
+ if class_labels is None:
810
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
811
+
812
+ if self.config.class_embed_type == "timestep":
813
+ class_labels = self.time_proj(class_labels)
814
+
815
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
816
+ emb = emb + class_emb
817
+
818
+ if self.config.addition_embed_type is not None:
819
+ if self.config.addition_embed_type == "text":
820
+ aug_emb = self.add_embedding(encoder_hidden_states)
821
+
822
+ elif self.config.addition_embed_type == "text_time":
823
+ if "text_embeds" not in added_cond_kwargs:
824
+ raise ValueError(
825
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
826
+ )
827
+ text_embeds = added_cond_kwargs.get("text_embeds")
828
+ if "time_ids" not in added_cond_kwargs:
829
+ raise ValueError(
830
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
831
+ )
832
+ time_ids = added_cond_kwargs.get("time_ids")
833
+ time_embeds = self.add_time_proj(time_ids.flatten())
834
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
835
+
836
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
837
+ add_embeds = add_embeds.to(emb.dtype)
838
+ aug_emb = self.add_embedding(add_embeds)
839
+
840
+ # Copyright by Qi Xin(2024/07/06)
841
+ # inject control type info to time embedding to distinguish different control conditions
842
+ control_type = added_cond_kwargs.get('control_type')
843
+ control_embeds = self.control_type_proj(control_type.flatten())
844
+ control_embeds = control_embeds.reshape((t_emb.shape[0], -1))
845
+ control_embeds = control_embeds.to(emb.dtype)
846
+ control_emb = self.control_add_embedding(control_embeds)
847
+ emb = emb + control_emb
848
+ #---------------------------------------------------------------------------------
849
+
850
+ emb = emb + aug_emb if aug_emb is not None else emb
851
+
852
+ # 2. pre-process
853
+ sample = self.conv_in(sample)
854
+ indices = torch.nonzero(control_type[0])
855
+
856
+ # Copyright by Qi Xin(2024/07/06)
857
+ # add single/multi conditions to the input image.
858
+ # Condition Transformer provides an easy and effective way to fuse different features naturally
859
+ inputs = []
860
+ condition_list = []
861
+
862
+ for idx in range(indices.shape[0] + 1):
863
+ if idx == indices.shape[0]:
864
+ controlnet_cond = sample
865
+ feat_seq = torch.mean(controlnet_cond, dim=(2, 3)) # N * C
866
+ else:
867
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond_list[indices[idx][0]])
868
+ feat_seq = torch.mean(controlnet_cond, dim=(2, 3)) # N * C
869
+ feat_seq = feat_seq + self.task_embedding[indices[idx][0]]
870
+
871
+ inputs.append(feat_seq.unsqueeze(1))
872
+ condition_list.append(controlnet_cond)
873
+
874
+ x = torch.cat(inputs, dim=1) # NxLxC
875
+ x = self.transformer_layes(x)
876
+
877
+ controlnet_cond_fuser = sample * 0.0
878
+ for idx in range(indices.shape[0]):
879
+ alpha = self.spatial_ch_projs(x[:, idx])
880
+ alpha = alpha.unsqueeze(-1).unsqueeze(-1)
881
+ controlnet_cond_fuser += condition_list[idx] + alpha
882
+
883
+ sample = sample + controlnet_cond_fuser
884
+ #-------------------------------------------------------------------------------------------
885
+
886
+ # 3. down
887
+ down_block_res_samples = (sample,)
888
+ for downsample_block in self.down_blocks:
889
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
890
+ sample, res_samples = downsample_block(
891
+ hidden_states=sample,
892
+ temb=emb,
893
+ encoder_hidden_states=encoder_hidden_states,
894
+ attention_mask=attention_mask,
895
+ cross_attention_kwargs=cross_attention_kwargs,
896
+ )
897
+ else:
898
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
899
+
900
+ down_block_res_samples += res_samples
901
+
902
+ # 4. mid
903
+ if self.mid_block is not None:
904
+ sample = self.mid_block(
905
+ sample,
906
+ emb,
907
+ encoder_hidden_states=encoder_hidden_states,
908
+ attention_mask=attention_mask,
909
+ cross_attention_kwargs=cross_attention_kwargs,
910
+ )
911
+
912
+ # 5. Control net blocks
913
+
914
+ controlnet_down_block_res_samples = ()
915
+
916
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
917
+ down_block_res_sample = controlnet_block(down_block_res_sample)
918
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
919
+
920
+ down_block_res_samples = controlnet_down_block_res_samples
921
+
922
+ mid_block_res_sample = self.controlnet_mid_block(sample)
923
+
924
+ # 6. scaling
925
+ if guess_mode and not self.config.global_pool_conditions:
926
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
927
+ scales = scales * conditioning_scale
928
+ down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
929
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
930
+ else:
931
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
932
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
933
+
934
+ if self.config.global_pool_conditions:
935
+ down_block_res_samples = [
936
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
937
+ ]
938
+ mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
939
+
940
+ if not return_dict:
941
+ return (down_block_res_samples, mid_block_res_sample)
942
+
943
+ return ControlNetOutput(
944
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
945
+ )
946
+
947
+
948
+
949
+ def zero_module(module):
950
+ for p in module.parameters():
951
+ nn.init.zeros_(p)
952
+ return module
953
+
954
+
955
+
956
+
957
+
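A minimal usage sketch for the ControlNetModel_Union added above (not part of the uploaded files): it assumes an SDXL UNet2DConditionModel is already loaded as `base_unet`, and that `noisy_latents`, `timesteps`, `prompt_embeds`, `text_embeds`, `time_ids`, and `cond_images` are placeholder tensors prepared by the caller.

# Sketch only; names other than the class/method names defined above are placeholders.
controlnet = ControlNetModel_Union.from_unet(base_unet, load_weights_from_unet=True)

# control_type is a per-batch 0/1 vector of length num_control_type (6 by default);
# forward() reads it from added_cond_kwargs and only consumes the flagged entries
# of controlnet_cond_list.
control_type = torch.tensor([[1, 0, 0, 0, 0, 0]], dtype=torch.float32)
added_cond_kwargs = {"text_embeds": text_embeds, "time_ids": time_ids, "control_type": control_type}

down_samples, mid_sample = controlnet(
    sample=noisy_latents,
    timestep=timesteps,
    encoder_hidden_states=prompt_embeds,
    controlnet_cond_list=cond_images,   # list indexed by control type; only flagged slots are read
    conditioning_scale=1.0,
    added_cond_kwargs=added_cond_kwargs,
    return_dict=False,
)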
t2i/controlnet_union/pipeline/pipeline_controlnet_union_inpaint_sd_xl.py CHANGED
The diff for this file is too large to render. See raw diff
 
t2i/controlnet_union/pipeline/pipeline_controlnet_union_sd_xl.py CHANGED
The diff for this file is too large to render. See raw diff
 
t2i/controlnet_union/pipeline/pipeline_controlnet_union_sd_xl_img2img.py CHANGED
The diff for this file is too large to render. See raw diff
 
t2i/pipe.py CHANGED
@@ -1,157 +1,157 @@
1
- import os, subprocess, time, datetime, inspect
2
- from typing import Any, Tuple, Dict, List, Optional
3
- from dataclasses import dataclass, field
4
- import torch
5
- from diffusers import DiffusionPipeline, AutoencoderKL
6
- from diffusers.models.attention_processor import AttnProcessor2_0
7
- from t2i_config import models, sdxl_vaes, sd15_vaes, PIPELINE_MAX_GIB
8
- from t2i.utils import (logger, get_token, free_memory, calc_pipe_size, is_weight_url, get_file,
9
- get_model_type, get_model_type_from_pipe, get_task_class, DEFAULT_TASKS, IS_ZEROGPU, DEVICE, DTYPE, IS_QUANT,
10
- MAX_SEED, MAX_IMAGE_SIZE, DEFAULT_MODEL_TYPE, DEFAULT_STR, ASPECT_RATIOS, PIPELINE_TYPES, DEFAULT_VAE, PARAM_MODES)
11
-
12
-
13
- if IS_ZEROGPU:
14
- logger.info("Running on Zero GPU.")
15
- os.environ["ZEROGPU_SIZE"] = "auto" # https://huggingface.co/posts/cbensimon/356529804559377
16
- subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
17
- torch.set_float32_matmul_precision("high") # https://pytorch.org/blog/accelerating-generative-ai-3/
18
- logger.info(f"Using device: {DEVICE}")
19
- logger.info(f"Using dtype: {DTYPE}")
20
-
21
-
22
- #from torchao.quantization.quant_api import Int8WeightOnlyConfig, quantize_
23
- @dataclass(order=True)
24
- class Pipeline:
25
- name: str = ""
26
- pipe: Any = field(default_factory=Any)
27
- lastmod: float = 0.
28
- size: int = 0
29
- type: str = DEFAULT_MODEL_TYPE
30
- pipe_type: str = PIPELINE_TYPES[0]
31
-
32
- def __str__(self):
33
- return f"{self.name} ({type(self.pipe).__name__} {self.type} {self.pipe_type}) Size:{float(self.size) / (1024.**3):.2f}GiB LastMod.:{datetime.datetime.fromtimestamp(self.lastmod).strftime('%Y/%m/%d %H:%M:%S')}"
34
-
35
- def __del__(self):
36
- if not self.pipe: return
37
- self.pipe.to("cpu")
38
- del self.pipe
39
- free_memory()
40
- logger.debug(f"Unloaded pipeline {self.name}.")
41
-
42
- def onload(self, device: str, model_type: str) -> Any:
43
- self.lastmod = time.time()
44
- if device != "cpu" and not IS_QUANT:
45
- if self.pipe.device != device: self.pipe.to(device)
46
- # https://huggingface.co/docs/diffusers/main/en/optimization/torch2.0
47
- #if model_type in ["SD 1.5", "SDXL"]: self.pipe.unet.set_attn_processor(AttnProcessor2_0())
48
- #elif model_type in ["FLUX"]: self.pipe.transformer.set_attn_processor(AttnProcessor2_0())
49
- #self.pipe.vae.set_attn_processor(AttnProcessor2_0())
50
- #logger.debug(f"SDPA enabled {type(self.pipe).__name__} ({model_type}) on {device}.") # by default in PyTorch 2.x
51
- return self.pipe
52
-
53
- def quantize(self):
54
- if not IS_QUANT: return self
55
- #if self.type in ["SD 1.5", "SDXL"]: quantize_(self.pipe.unet, Int8WeightOnlyConfig())
56
- #elif self.type in ["FLUX"]: quantize_(self.pipe.transformer, Int8WeightOnlyConfig())
57
- self.size=calc_pipe_size(self.pipe)
58
- logger.debug(f"Quantized pipeline {self.name}.")
59
- return self
60
-
61
-
62
- class Pipelines:
63
- def __init__(self):
64
- self.pipes: Dict[str, Pipeline] = {}
65
- self.max_gib = PIPELINE_MAX_GIB
66
-
67
- def __call__(self, name: str, device: str="cpu", model_type: str=DEFAULT_MODEL_TYPE, pipe_type: str=PIPELINE_TYPES[0]) -> Any:
68
- try:
69
- if name in self.pipes.keys():
70
- pipe = self.pipes[name].onload(device, model_type)
71
- free_memory()
72
- return pipe
73
- if model_type == DEFAULT_MODEL_TYPE: model_type = get_model_type(name)
74
- pipe_class = get_task_class(model_type, DEFAULT_TASKS[0])
75
- if is_weight_url(name):
76
- path = get_file(name)
77
- if model_type == "SDXL": pipe = pipe_class.from_single_file(path, add_watermarker=False, torch_dtype=DTYPE)
78
- elif model_type == "SD 1.5": pipe = pipe_class.from_single_file(path, torch_dtype=DTYPE)
79
- elif model_type == "FLUX": pipe = pipe_class.from_single_file(path, torch_dtype=DTYPE) #
80
- else: raise Exception(f"Invalid architecture {name}")
81
- else:
82
- if model_type == "SDXL": pipe = pipe_class.from_pretrained(name, add_watermarker=False, torch_dtype=DTYPE)
83
- elif model_type == "SD 1.5": pipe = pipe_class.from_pretrained(name, torch_dtype=DTYPE)
84
- elif model_type == "FLUX": pipe = pipe_class.from_pretrained(name, torch_dtype=DTYPE) #
85
- else:
86
- pipe = pipe_class.from_pretrained(name, torch_dtype=DTYPE)
87
- model_type = get_model_type_from_pipe(pipe)
88
- if pipe_type == "Long Prompt Weighting" and model_type in ["SD 1.5", "SDXL"]:
89
- if model_type == "SD 1.5": pipe = DiffusionPipeline.from_pipe(pipe, custom_pipeline="lpw_stable_diffusion", torch_dtype=DTYPE)
90
- elif model_type == "SDXL": pipe = DiffusionPipeline.from_pipe(pipe, custom_pipeline="lpw_stable_diffusion_xl", add_watermarker=False, torch_dtype=DTYPE)
91
- self.pipes[name] = Pipeline(name=name, pipe=pipe, lastmod=time.time(), size=calc_pipe_size(pipe), type=model_type, pipe_type=pipe_type)#.quantize()
92
- logger.info(f"Loaded {self.pipes[name]}.")
93
- self.clean()
94
- pipe = self.pipes[name].onload(device, model_type)
95
- free_memory()
96
- return pipe
97
- except Exception as e:
98
- logger.info(f"Failed to load pipeline for {name} {e}")
99
- return None
100
-
101
- def get_model_type(self, name: str) -> str:
102
- if name in self.pipes.keys(): return self.pipes[name].type
103
- else: return DEFAULT_MODEL_TYPE
104
-
105
- def __str__(self):
106
- return "\n".join([str(x) for x in self.pipes.values()])
107
-
108
- def clean(self):
109
- items = sorted(list(self.pipes.values()), key=lambda x:x.lastmod, reverse=True)
110
- sum_bytes = 0
111
- max_bytes = self.max_gib * (1024 ** 3)
112
- del_items = []
113
- for i, item in enumerate(items):
114
- sum_bytes += item.size
115
- if sum_bytes > max_bytes and i > 0: del_items.append(item.name)
116
- for item in del_items:
117
- self.pipes.pop(item)
118
- logger.debug(f"Unloaded {item}.")
119
-
120
-
121
- pipes = Pipelines()
122
-
123
-
124
- def get_current_model_type(name: str) -> str:
125
- return pipes.get_model_type(name)
126
-
127
-
128
- VAE_NAMES = [DEFAULT_VAE] + sdxl_vaes + sd15_vaes
129
-
130
-
131
- def get_vae(pipe: Any, name: str, device: str, model_type: str=DEFAULT_MODEL_TYPE):
132
- if name == DEFAULT_VAE or not pipe: return pipe
133
- try:
134
- model_type = get_current_model_type(name)
135
- if (model_type == "SDXL" and name in sd15_vaes) or (model_type == "SD 1.5" and name in sdxl_vaes): return pipe
136
- if is_weight_url(name): vae = AutoencoderKL.from_single_file(get_file(name), torch_dtype=DTYPE)
137
- else: vae = AutoencoderKL.from_pretrained(name, torch_dtype=DTYPE)
138
- if vae:
139
- if device != "cpu" and vae.device != device: vae.to(device)
140
- pipe.vae = vae
141
- logger.info(f"VAE loaded {name}.")
142
- return pipe
143
- except Exception as e:
144
- logger.info(f"{inspect.currentframe().f_code.co_name}: {e}")
145
- return pipe
146
-
147
-
148
- def get_pipe(name: str, device: str="cpu", model_type: str=DEFAULT_MODEL_TYPE, pipe_type: str=PIPELINE_TYPES[0]):
149
- global pipes
150
- try:
151
- pipe = pipes(name, device, model_type, pipe_type)
152
- return pipe
153
- except Exception as e:
154
- logger.info(f"{inspect.currentframe().f_code.co_name}: {e}")
155
- return None
156
- finally:
157
- logger.debug(f"Current pipes: {pipes}")
 
1
+ import os, subprocess, time, datetime, inspect
2
+ from typing import Any, Tuple, Dict, List, Optional
3
+ from dataclasses import dataclass, field
4
+ import torch
5
+ from diffusers import DiffusionPipeline, AutoencoderKL
6
+ from diffusers.models.attention_processor import AttnProcessor2_0
7
+ from t2i_config import models, sdxl_vaes, sd15_vaes, PIPELINE_MAX_GIB
8
+ from t2i.utils import (logger, get_token, free_memory, calc_pipe_size, is_weight_url, get_file,
9
+ get_model_type, get_model_type_from_pipe, get_task_class, DEFAULT_TASKS, IS_ZEROGPU, DEVICE, DTYPE, IS_QUANT,
10
+ MAX_SEED, MAX_IMAGE_SIZE, DEFAULT_MODEL_TYPE, DEFAULT_STR, ASPECT_RATIOS, PIPELINE_TYPES, DEFAULT_VAE, PARAM_MODES)
11
+
12
+
13
+ if IS_ZEROGPU:
14
+ logger.info("Running on Zero GPU.")
15
+ os.environ["ZEROGPU_SIZE"] = "auto" # https://huggingface.co/posts/cbensimon/356529804559377
16
+ subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", shell=True)
17
+ torch.set_float32_matmul_precision("high") # https://pytorch.org/blog/accelerating-generative-ai-3/
18
+ logger.info(f"Using device: {DEVICE}")
19
+ logger.info(f"Using dtype: {DTYPE}")
20
+
21
+
22
+ #from torchao.quantization.quant_api import Int8WeightOnlyConfig, quantize_
23
+ @dataclass(order=True)
24
+ class Pipeline:
25
+ name: str = ""
26
+ pipe: Any = field(default_factory=Any)
27
+ lastmod: float = 0.
28
+ size: int = 0
29
+ type: str = DEFAULT_MODEL_TYPE
30
+ pipe_type: str = PIPELINE_TYPES[0]
31
+
32
+ def __str__(self):
33
+ return f"{self.name} ({type(self.pipe).__name__} {self.type} {self.pipe_type}) Size:{float(self.size) / (1024.**3):.2f}GiB LastMod.:{datetime.datetime.fromtimestamp(self.lastmod).strftime('%Y/%m/%d %H:%M:%S')}"
34
+
35
+ def __del__(self):
36
+ if not self.pipe: return
37
+ self.pipe.to("cpu")
38
+ del self.pipe
39
+ free_memory()
40
+ logger.debug(f"Unloaded pipeline {self.name}.")
41
+
42
+ def onload(self, device: str, model_type: str) -> Any:
43
+ self.lastmod = time.time()
44
+ if device != "cpu" and not IS_QUANT:
45
+ if self.pipe.device != device: self.pipe.to(device)
46
+ # https://huggingface.co/docs/diffusers/main/en/optimization/torch2.0
47
+ #if model_type in ["SD 1.5", "SDXL"]: self.pipe.unet.set_attn_processor(AttnProcessor2_0())
48
+ #elif model_type in ["FLUX"]: self.pipe.transformer.set_attn_processor(AttnProcessor2_0())
49
+ #self.pipe.vae.set_attn_processor(AttnProcessor2_0())
50
+ #logger.debug(f"SDPA enabled {type(self.pipe).__name__} ({model_type}) on {device}.") # by default in PyTorch 2.x
51
+ return self.pipe
52
+
53
+ def quantize(self):
54
+ if not IS_QUANT: return self
55
+ #if self.type in ["SD 1.5", "SDXL"]: quantize_(self.pipe.unet, Int8WeightOnlyConfig())
56
+ #elif self.type in ["FLUX"]: quantize_(self.pipe.transformer, Int8WeightOnlyConfig())
57
+ self.size=calc_pipe_size(self.pipe)
58
+ logger.debug(f"Quantized pipeline {self.name}.")
59
+ return self
60
+
61
+
62
+ class Pipelines:
63
+ def __init__(self):
64
+ self.pipes: Dict[str, Pipeline] = {}
65
+ self.max_gib = PIPELINE_MAX_GIB
66
+
67
+ def __call__(self, name: str, device: str="cpu", model_type: str=DEFAULT_MODEL_TYPE, pipe_type: str=PIPELINE_TYPES[0]) -> Any:
68
+ try:
69
+ if name in self.pipes.keys():
70
+ pipe = self.pipes[name].onload(device, model_type)
71
+ free_memory()
72
+ return pipe
73
+ if model_type == DEFAULT_MODEL_TYPE: model_type = get_model_type(name)
74
+ pipe_class = get_task_class(model_type, DEFAULT_TASKS[0])
75
+ if is_weight_url(name):
76
+ path = get_file(name)
77
+ if model_type == "SDXL": pipe = pipe_class.from_single_file(path, add_watermarker=False, torch_dtype=DTYPE)
78
+ elif model_type == "SD 1.5": pipe = pipe_class.from_single_file(path, torch_dtype=DTYPE)
79
+ elif model_type == "FLUX": pipe = pipe_class.from_single_file(path, torch_dtype=DTYPE) #
80
+ else: raise Exception(f"Invalid architecture {name}")
81
+ else:
82
+ if model_type == "SDXL": pipe = pipe_class.from_pretrained(name, add_watermarker=False, torch_dtype=DTYPE)
83
+ elif model_type == "SD 1.5": pipe = pipe_class.from_pretrained(name, torch_dtype=DTYPE)
84
+ elif model_type == "FLUX": pipe = pipe_class.from_pretrained(name, torch_dtype=DTYPE) #
85
+ else:
86
+ pipe = pipe_class.from_pretrained(name, torch_dtype=DTYPE)
87
+ model_type = get_model_type_from_pipe(pipe)
88
+ if pipe_type == "Long Prompt Weighting" and model_type in ["SD 1.5", "SDXL"]:
89
+ if model_type == "SD 1.5": pipe = DiffusionPipeline.from_pipe(pipe, custom_pipeline="lpw_stable_diffusion", torch_dtype=DTYPE)
90
+ elif model_type == "SDXL": pipe = DiffusionPipeline.from_pipe(pipe, custom_pipeline="lpw_stable_diffusion_xl", add_watermarker=False, torch_dtype=DTYPE)
91
+ self.pipes[name] = Pipeline(name=name, pipe=pipe, lastmod=time.time(), size=calc_pipe_size(pipe), type=model_type, pipe_type=pipe_type)#.quantize()
92
+ logger.info(f"Loaded {self.pipes[name]}.")
93
+ self.clean()
94
+ pipe = self.pipes[name].onload(device, model_type)
95
+ free_memory()
96
+ return pipe
97
+ except Exception as e:
98
+ logger.info(f"Failed to load pipeline for {name} {e}")
99
+ return None
100
+
101
+ def get_model_type(self, name: str) -> str:
102
+ if name in self.pipes.keys(): return self.pipes[name].type
103
+ else: return DEFAULT_MODEL_TYPE
104
+
105
+ def __str__(self):
106
+ return "\n".join([str(x) for x in self.pipes.values()])
107
+
108
+ def clean(self):
109
+ items = sorted(list(self.pipes.values()), key=lambda x:x.lastmod, reverse=True)
110
+ sum_bytes = 0
111
+ max_bytes = self.max_gib * (1024 ** 3)
112
+ del_items = []
113
+ for i, item in enumerate(items):
114
+ sum_bytes += item.size
115
+ if sum_bytes > max_bytes and i > 0: del_items.append(item.name)
116
+ for item in del_items:
117
+ self.pipes.pop(item)
118
+ logger.debug(f"Unloaded {item}.")
119
+
120
+
121
+ pipes = Pipelines()
122
+
123
+
124
+ def get_current_model_type(name: str) -> str:
125
+ return pipes.get_model_type(name)
126
+
127
+
128
+ VAE_NAMES = [DEFAULT_VAE] + sdxl_vaes + sd15_vaes
129
+
130
+
131
+ def get_vae(pipe: Any, name: str, device: str, model_type: str=DEFAULT_MODEL_TYPE):
132
+ if name == DEFAULT_VAE or not pipe: return pipe
133
+ try:
134
+ model_type = get_current_model_type(name)
135
+ if (model_type == "SDXL" and name in sd15_vaes) or (model_type == "SD 1.5" and name in sdxl_vaes): return pipe
136
+ if is_weight_url(name): vae = AutoencoderKL.from_single_file(get_file(name), torch_dtype=DTYPE)
137
+ else: vae = AutoencoderKL.from_pretrained(name, torch_dtype=DTYPE)
138
+ if vae:
139
+ if device != "cpu" and vae.device != device: vae.to(device)
140
+ pipe.vae = vae
141
+ logger.info(f"VAE loaded {name}.")
142
+ return pipe
143
+ except Exception as e:
144
+ logger.info(f"{inspect.currentframe().f_code.co_name}: {e}")
145
+ return pipe
146
+
147
+
148
+ def get_pipe(name: str, device: str="cpu", model_type: str=DEFAULT_MODEL_TYPE, pipe_type: str=PIPELINE_TYPES[0]):
149
+ global pipes
150
+ try:
151
+ pipe = pipes(name, device, model_type, pipe_type)
152
+ return pipe
153
+ except Exception as e:
154
+ logger.info(f"{inspect.currentframe().f_code.co_name}: {e}")
155
+ return None
156
+ finally:
157
+ logger.debug(f"Current pipes: {pipes}")