add support for controlnet and t2i adapter too
Browse files
app.py
CHANGED
|
@@ -11,6 +11,24 @@ COMPONENT_FILTER = [
|
|
| 11 |
"_diffusers_version",
|
| 12 |
]
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
def format_size(num: int) -> str:
|
| 16 |
"""Format size in bytes into a human-readable string.
|
|
@@ -24,11 +42,21 @@ def format_size(num: int) -> str:
|
|
| 24 |
return f"{num_f:.1f}Y"
|
| 25 |
|
| 26 |
|
| 27 |
-
def format_output(pipeline_id, memory_mapping):
|
| 28 |
markdown_str = f"## {pipeline_id}\n"
|
|
|
|
| 29 |
if memory_mapping:
|
| 30 |
for component, memory in memory_mapping.items():
|
| 31 |
markdown_str += f"* {component}: {format_size(memory)}\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
return markdown_str
|
| 33 |
|
| 34 |
|
|
@@ -39,7 +67,35 @@ def load_model_index(pipeline_id, token=None, revision=None):
|
|
| 39 |
return index_dict
|
| 40 |
|
| 41 |
|
| 42 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
if token == "":
|
| 44 |
token = None
|
| 45 |
|
|
@@ -49,12 +105,31 @@ def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=No
|
|
| 49 |
if variant == "fp32":
|
| 50 |
variant = None
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
print(f"pipeline_id: {pipeline_id}, variant: {variant}, revision: {revision}, extension: {extension}")
|
| 53 |
|
|
|
|
| 54 |
files_in_repo = model_info(pipeline_id, revision=revision, token=token, files_metadata=True).siblings
|
| 55 |
index_dict = load_model_index(pipeline_id, token=token, revision=revision)
|
| 56 |
|
| 57 |
-
# Check if all the concerned components have the checkpoints in
|
|
|
|
| 58 |
print(f"Index dict: {index_dict}")
|
| 59 |
for current_component in index_dict:
|
| 60 |
if (
|
|
@@ -63,6 +138,7 @@ def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=No
|
|
| 63 |
and len(index_dict[current_component]) == 2
|
| 64 |
):
|
| 65 |
current_component_fileobjs = list(filter(lambda x: current_component in x.rfilename, files_in_repo))
|
|
|
|
| 66 |
if current_component_fileobjs:
|
| 67 |
current_component_filenames = [fileobj.rfilename for fileobj in current_component_fileobjs]
|
| 68 |
condition = ( # noqa: E731
|
|
@@ -119,16 +195,20 @@ def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=No
|
|
| 119 |
if selected_file is not None:
|
| 120 |
component_wise_memory[component] = selected_file.size
|
| 121 |
|
| 122 |
-
return format_output(pipeline_id, component_wise_memory)
|
| 123 |
|
| 124 |
|
| 125 |
with gr.Interface(
|
| 126 |
title="Compute component-wise memory of a 🧨 Diffusers pipeline.",
|
| 127 |
description="Pipelines containing text encoders with sharded checkpoints are also supported"
|
| 128 |
-
" (PixArt-Alpha, for example) 🤗"
|
|
|
|
|
|
|
| 129 |
fn=get_component_wise_memory,
|
| 130 |
inputs=[
|
| 131 |
gr.components.Textbox(lines=1, label="pipeline_id", info="Example: runwayml/stable-diffusion-v1-5"),
|
|
|
|
|
|
|
| 132 |
gr.components.Textbox(lines=1, label="hf_token", info="Pass this in case of private repositories."),
|
| 133 |
gr.components.Radio(
|
| 134 |
["fp32", "fp16", "bf16"],
|
|
@@ -144,11 +224,20 @@ with gr.Interface(
|
|
| 144 |
],
|
| 145 |
outputs=[gr.Markdown(label="Output")],
|
| 146 |
examples=[
|
| 147 |
-
["runwayml/stable-diffusion-v1-5", None, "fp32", None, ".safetensors"],
|
| 148 |
-
["
|
| 149 |
-
["
|
| 150 |
-
[
|
| 151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
],
|
| 153 |
theme=gr.themes.Soft(),
|
| 154 |
allow_flagging="never",
|
|
|
|
| 11 |
"_diffusers_version",
|
| 12 |
]
|
| 13 |
|
| 14 |
+
ARTICLE = """
|
| 15 |
+
## Notes on how to use the `controlnet_id` and `t2i_adapter_id` fields
|
| 16 |
+
|
| 17 |
+
Both `controlnet_id` and `t2i_adapter_id` fields support passing multiple checkpoint ids,
|
| 18 |
+
e.g., "thibaud/controlnet-openpose-sdxl-1.0,diffusers/controlnet-canny-sdxl-1.0". For
|
| 19 |
+
`t2i_adapter_id`, this could be like - "TencentARC/t2iadapter_keypose_sd14v1,TencentARC/t2iadapter_depth_sd14v1".
|
| 20 |
+
|
| 21 |
+
Users should take care of passing the underlying base `pipeline_id` appropriately. For example,
|
| 22 |
+
passing `pipeline_id` as "runwayml/stable-diffusion-v1-5" and `controlnet_id` as "thibaud/controlnet-openpose-sdxl-1.0"
|
| 23 |
+
won't result in an error but these two things aren't meant to compatible. You should pass
|
| 24 |
+
a `controlnet_id` that is compatible with "runwayml/stable-diffusion-v1-5".
|
| 25 |
+
|
| 26 |
+
For further clarification on this topic, feel free to open a [discussion](https://huggingface.co/spaces/diffusers/compute-pipeline-size/discussions).
|
| 27 |
+
|
| 28 |
+
📔 Also, note that `revision` field is only reserved for `pipeline_id`. It won't have any effect on the
|
| 29 |
+
`controlnet_id` or `t2i_adapter_id`.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
|
| 33 |
def format_size(num: int) -> str:
|
| 34 |
"""Format size in bytes into a human-readable string.
|
|
|
|
| 42 |
return f"{num_f:.1f}Y"
|
| 43 |
|
| 44 |
|
| 45 |
+
def format_output(pipeline_id, memory_mapping, controlnet_mapping=None, t2i_adapter_mapping=None):
|
| 46 |
markdown_str = f"## {pipeline_id}\n"
|
| 47 |
+
|
| 48 |
if memory_mapping:
|
| 49 |
for component, memory in memory_mapping.items():
|
| 50 |
markdown_str += f"* {component}: {format_size(memory)}\n"
|
| 51 |
+
if controlnet_mapping:
|
| 52 |
+
markdown_str += "\n## ControlNet(s)\n"
|
| 53 |
+
for controlnet_id, memory in controlnet_mapping.items():
|
| 54 |
+
markdown_str += f"* {controlnet_id}: {format_size(memory)}\n"
|
| 55 |
+
if t2i_adapter_mapping:
|
| 56 |
+
markdown_str += "\n## T2I-Adapters(s)\n"
|
| 57 |
+
for t2_adapter_id, memory in t2i_adapter_mapping.items():
|
| 58 |
+
markdown_str += f"* {t2_adapter_id}: {format_size(memory)}\n"
|
| 59 |
+
|
| 60 |
return markdown_str
|
| 61 |
|
| 62 |
|
|
|
|
| 67 |
return index_dict
|
| 68 |
|
| 69 |
|
| 70 |
+
def get_individual_model_memory(id, token, variant, extension):
|
| 71 |
+
files_in_repo = model_info(id, token=token, files_metadata=True).siblings
|
| 72 |
+
for x in files_in_repo:
|
| 73 |
+
if extension in x.rfilename:
|
| 74 |
+
if variant:
|
| 75 |
+
if variant in x.rfilename:
|
| 76 |
+
return x.size
|
| 77 |
+
else:
|
| 78 |
+
return x.size
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def get_component_wise_memory(
|
| 82 |
+
pipeline_id,
|
| 83 |
+
controlnet_id=None,
|
| 84 |
+
t2i_adapter_id=None,
|
| 85 |
+
token=None,
|
| 86 |
+
variant=None,
|
| 87 |
+
revision=None,
|
| 88 |
+
extension=".safetensors",
|
| 89 |
+
):
|
| 90 |
+
if controlnet_id == "":
|
| 91 |
+
controlnet_id = None
|
| 92 |
+
|
| 93 |
+
if t2i_adapter_id == "":
|
| 94 |
+
t2i_adapter_id = None
|
| 95 |
+
|
| 96 |
+
if controlnet_id and t2i_adapter_id:
|
| 97 |
+
raise ValueError("Both `controlnet_id` and `t2i_adapter_id` cannot be provided.")
|
| 98 |
+
|
| 99 |
if token == "":
|
| 100 |
token = None
|
| 101 |
|
|
|
|
| 105 |
if variant == "fp32":
|
| 106 |
variant = None
|
| 107 |
|
| 108 |
+
# Handle ControlNet and T2I-Adapter.
|
| 109 |
+
controlnet_mapping = t2_adapter_mapping = None
|
| 110 |
+
if controlnet_id is not None:
|
| 111 |
+
controlnet_id = controlnet_id.split(",")
|
| 112 |
+
controlnet_sizes = [
|
| 113 |
+
get_individual_model_memory(id_, token=token, variant=variant, extension=extension)
|
| 114 |
+
for id_ in controlnet_id
|
| 115 |
+
]
|
| 116 |
+
controlnet_mapping = dict(zip(controlnet_id, controlnet_sizes))
|
| 117 |
+
elif t2i_adapter_id is not None:
|
| 118 |
+
t2i_adapter_id = t2i_adapter_id.split(",")
|
| 119 |
+
t2i_adapter_sizes = [
|
| 120 |
+
get_individual_model_memory(id_, token=token, variant=variant, extension=extension)
|
| 121 |
+
for id_ in t2i_adapter_id
|
| 122 |
+
]
|
| 123 |
+
t2_adapter_mapping = dict(zip(t2i_adapter_id, t2i_adapter_sizes))
|
| 124 |
+
|
| 125 |
print(f"pipeline_id: {pipeline_id}, variant: {variant}, revision: {revision}, extension: {extension}")
|
| 126 |
|
| 127 |
+
# Load pipeline metadata.
|
| 128 |
files_in_repo = model_info(pipeline_id, revision=revision, token=token, files_metadata=True).siblings
|
| 129 |
index_dict = load_model_index(pipeline_id, token=token, revision=revision)
|
| 130 |
|
| 131 |
+
# Check if all the concerned components have the checkpoints in
|
| 132 |
+
# the requested "variant" and "extension".
|
| 133 |
print(f"Index dict: {index_dict}")
|
| 134 |
for current_component in index_dict:
|
| 135 |
if (
|
|
|
|
| 138 |
and len(index_dict[current_component]) == 2
|
| 139 |
):
|
| 140 |
current_component_fileobjs = list(filter(lambda x: current_component in x.rfilename, files_in_repo))
|
| 141 |
+
|
| 142 |
if current_component_fileobjs:
|
| 143 |
current_component_filenames = [fileobj.rfilename for fileobj in current_component_fileobjs]
|
| 144 |
condition = ( # noqa: E731
|
|
|
|
| 195 |
if selected_file is not None:
|
| 196 |
component_wise_memory[component] = selected_file.size
|
| 197 |
|
| 198 |
+
return format_output(pipeline_id, component_wise_memory, controlnet_mapping, t2_adapter_mapping)
|
| 199 |
|
| 200 |
|
| 201 |
with gr.Interface(
|
| 202 |
title="Compute component-wise memory of a 🧨 Diffusers pipeline.",
|
| 203 |
description="Pipelines containing text encoders with sharded checkpoints are also supported"
|
| 204 |
+
" (PixArt-Alpha, for example) 🤗 See instructions below the form on how to pass"
|
| 205 |
+
" `controlnet_id` or `t2_adapter_id`.",
|
| 206 |
+
article=ARTICLE,
|
| 207 |
fn=get_component_wise_memory,
|
| 208 |
inputs=[
|
| 209 |
gr.components.Textbox(lines=1, label="pipeline_id", info="Example: runwayml/stable-diffusion-v1-5"),
|
| 210 |
+
gr.components.Textbox(lines=1, label="controlnet_id", info="Example: lllyasviel/sd-controlnet-canny"),
|
| 211 |
+
gr.components.Textbox(lines=1, label="t2i_adapter_id", info="Example: TencentARC/t2iadapter_color_sd14v1"),
|
| 212 |
gr.components.Textbox(lines=1, label="hf_token", info="Pass this in case of private repositories."),
|
| 213 |
gr.components.Radio(
|
| 214 |
["fp32", "fp16", "bf16"],
|
|
|
|
| 224 |
],
|
| 225 |
outputs=[gr.Markdown(label="Output")],
|
| 226 |
examples=[
|
| 227 |
+
["runwayml/stable-diffusion-v1-5", None, None, None, "fp32", None, ".safetensors"],
|
| 228 |
+
["PixArt-alpha/PixArt-XL-2-1024-MS", None, None, None, "fp32", None, ".safetensors"],
|
| 229 |
+
["runwayml/stable-diffusion-v1-5", "lllyasviel/sd-controlnet-canny", None, None, "fp32", None, ".safetensors"],
|
| 230 |
+
[
|
| 231 |
+
"stabilityai/stable-diffusion-xl-base-1.0",
|
| 232 |
+
None,
|
| 233 |
+
"TencentARC/t2i-adapter-lineart-sdxl-1.0,TencentARC/t2i-adapter-canny-sdxl-1.0",
|
| 234 |
+
None,
|
| 235 |
+
"fp32",
|
| 236 |
+
None,
|
| 237 |
+
".safetensors",
|
| 238 |
+
],
|
| 239 |
+
["stabilityai/stable-cascade", None, None, None, "bf16", None, ".safetensors"],
|
| 240 |
+
["Deci/DeciDiffusion-v2-0", None, None, None, "fp32", None, ".safetensors"],
|
| 241 |
],
|
| 242 |
theme=gr.themes.Soft(),
|
| 243 |
allow_flagging="never",
|