Update app.py
Browse files
app.py
CHANGED
|
@@ -4,6 +4,7 @@ import json
|
|
| 4 |
|
| 5 |
component_filter = ["scheduler", "safety_checker", "tokenizer"]
|
| 6 |
|
|
|
|
| 7 |
def format_size(num: int) -> str:
|
| 8 |
"""Format size in bytes into a human-readable string.
|
| 9 |
Taken from https://stackoverflow.com/a/1094933
|
|
@@ -15,6 +16,7 @@ def format_size(num: int) -> str:
|
|
| 15 |
num_f /= 1000.0
|
| 16 |
return f"{num_f:.1f}Y"
|
| 17 |
|
|
|
|
| 18 |
def format_output(pipeline_id, memory_mapping):
|
| 19 |
markdown_str = f"## {pipeline_id}\n"
|
| 20 |
if memory_mapping:
|
|
@@ -22,12 +24,14 @@ def format_output(pipeline_id, memory_mapping):
|
|
| 22 |
markdown_str += f"* {component}: {format_size(memory)}\n"
|
| 23 |
return markdown_str
|
| 24 |
|
|
|
|
| 25 |
def load_model_index(pipeline_id, token=None, revision=None):
|
| 26 |
index_path = hf_hub_download(repo_id=pipeline_id, filename="model_index.json", revision=revision, token=token)
|
| 27 |
with open(index_path, "r") as f:
|
| 28 |
index_dict = json.load(f)
|
| 29 |
return index_dict
|
| 30 |
|
|
|
|
| 31 |
def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=None, extension=".safetensors"):
|
| 32 |
if token == "":
|
| 33 |
token = None
|
|
@@ -48,17 +52,22 @@ def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=No
|
|
| 48 |
index_filter.extend(["_class_name", "_diffusers_version"])
|
| 49 |
for current_component in index_dict:
|
| 50 |
if current_component not in index_filter:
|
| 51 |
-
current_component_fileobjs =
|
| 52 |
if current_component_fileobjs:
|
| 53 |
current_component_filenames = [fileobj.rfilename for fileobj in current_component_fileobjs]
|
| 54 |
-
condition =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
variant_present_with_extension = any(condition(filename) for filename in current_component_filenames)
|
| 56 |
if not variant_present_with_extension:
|
| 57 |
-
raise ValueError(
|
|
|
|
|
|
|
| 58 |
else:
|
| 59 |
raise ValueError(f"Problem with {current_component}.")
|
| 60 |
|
| 61 |
-
|
| 62 |
# Handle text encoder separately when it's sharded.
|
| 63 |
is_text_encoder_shared = any(".index.json" in file_obj.rfilename for file_obj in files_in_repo)
|
| 64 |
component_wise_memory = {}
|
|
@@ -99,4 +108,37 @@ def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=No
|
|
| 99 |
print(selected_file.rfilename)
|
| 100 |
component_wise_memory[component] = selected_file.size
|
| 101 |
|
| 102 |
-
return format_output(pipeline_id, component_wise_memory)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
component_filter = ["scheduler", "safety_checker", "tokenizer"]
|
| 6 |
|
| 7 |
+
|
| 8 |
def format_size(num: int) -> str:
|
| 9 |
"""Format size in bytes into a human-readable string.
|
| 10 |
Taken from https://stackoverflow.com/a/1094933
|
|
|
|
| 16 |
num_f /= 1000.0
|
| 17 |
return f"{num_f:.1f}Y"
|
| 18 |
|
| 19 |
+
|
| 20 |
def format_output(pipeline_id, memory_mapping):
|
| 21 |
markdown_str = f"## {pipeline_id}\n"
|
| 22 |
if memory_mapping:
|
|
|
|
| 24 |
markdown_str += f"* {component}: {format_size(memory)}\n"
|
| 25 |
return markdown_str
|
| 26 |
|
| 27 |
+
|
| 28 |
def load_model_index(pipeline_id, token=None, revision=None):
|
| 29 |
index_path = hf_hub_download(repo_id=pipeline_id, filename="model_index.json", revision=revision, token=token)
|
| 30 |
with open(index_path, "r") as f:
|
| 31 |
index_dict = json.load(f)
|
| 32 |
return index_dict
|
| 33 |
|
| 34 |
+
|
| 35 |
def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=None, extension=".safetensors"):
|
| 36 |
if token == "":
|
| 37 |
token = None
|
|
|
|
| 52 |
index_filter.extend(["_class_name", "_diffusers_version"])
|
| 53 |
for current_component in index_dict:
|
| 54 |
if current_component not in index_filter:
|
| 55 |
+
current_component_fileobjs = list(filter(lambda x: current_component in x.rfilename, files_in_repo))
|
| 56 |
if current_component_fileobjs:
|
| 57 |
current_component_filenames = [fileobj.rfilename for fileobj in current_component_fileobjs]
|
| 58 |
+
condition = (
|
| 59 |
+
lambda filename: extension in filename and variant in filename
|
| 60 |
+
if variant is not None
|
| 61 |
+
else lambda filename: extension in filename
|
| 62 |
+
)
|
| 63 |
variant_present_with_extension = any(condition(filename) for filename in current_component_filenames)
|
| 64 |
if not variant_present_with_extension:
|
| 65 |
+
raise ValueError(
|
| 66 |
+
f"Requested extension ({extension}) and variant ({variant}) not present for {current_component}. Available files for this component:\n{current_component_filenames}."
|
| 67 |
+
)
|
| 68 |
else:
|
| 69 |
raise ValueError(f"Problem with {current_component}.")
|
| 70 |
|
|
|
|
| 71 |
# Handle text encoder separately when it's sharded.
|
| 72 |
is_text_encoder_shared = any(".index.json" in file_obj.rfilename for file_obj in files_in_repo)
|
| 73 |
component_wise_memory = {}
|
|
|
|
| 108 |
print(selected_file.rfilename)
|
| 109 |
component_wise_memory[component] = selected_file.size
|
| 110 |
|
| 111 |
+
return format_output(pipeline_id, component_wise_memory)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
gr.Interface(
|
| 115 |
+
title="Compute component-wise memory of a 🧨 Diffusers pipeline.",
|
| 116 |
+
description="Sizes will be reported in GB. Pipelines containing text encoders with sharded checkpoints are also supported (PixArt-Alpha, for example) 🤗",
|
| 117 |
+
fn=get_component_wise_memory,
|
| 118 |
+
inputs=[
|
| 119 |
+
gr.components.Textbox(lines=1, label="pipeline_id", info="Example: runwayml/stable-diffusion-v1-5"),
|
| 120 |
+
gr.components.Textbox(lines=1, label="hf_token", info="Pass this in case of private repositories."),
|
| 121 |
+
gr.components.Dropdown(
|
| 122 |
+
[
|
| 123 |
+
"fp32",
|
| 124 |
+
"fp16",
|
| 125 |
+
],
|
| 126 |
+
label="variant",
|
| 127 |
+
info="Precision to use for calculation.",
|
| 128 |
+
),
|
| 129 |
+
gr.components.Textbox(lines=1, label="revision", info="Repository revision to use."),
|
| 130 |
+
gr.components.Dropdown(
|
| 131 |
+
[".bin", ".safetensors"],
|
| 132 |
+
label="extension",
|
| 133 |
+
info="Extension to use.",
|
| 134 |
+
),
|
| 135 |
+
],
|
| 136 |
+
outputs=[gr.Markdown(label="Output")],
|
| 137 |
+
examples=[
|
| 138 |
+
["runwayml/stable-diffusion-v1-5", None, "fp32", None, ".safetensors"],
|
| 139 |
+
["stabilityai/stable-diffusion-xl-base-1.0", None, "fp16", None, ".safetensors"],
|
| 140 |
+
["PixArt-alpha/PixArt-XL-2-1024-MS", None, "fp32", None, ".safetensors"],
|
| 141 |
+
],
|
| 142 |
+
theme=gr.themes.Soft(),
|
| 143 |
+
allow_flagging=False,
|
| 144 |
+
).launch(show_error=True)
|