reobustify
Browse files
app.py
CHANGED
|
@@ -29,6 +29,8 @@ For further clarification on this topic, feel free to open a [discussion](https:
|
|
| 29 |
`controlnet_id` or `t2i_adapter_id`.
|
| 30 |
"""
|
| 31 |
|
|
|
|
|
|
|
| 32 |
|
| 33 |
def format_size(num: int) -> str:
|
| 34 |
"""Format size in bytes into a human-readable string.
|
|
@@ -69,13 +71,12 @@ def load_model_index(pipeline_id, token=None, revision=None):
|
|
| 69 |
|
| 70 |
def get_individual_model_memory(id, token, variant, extension):
|
| 71 |
files_in_repo = model_info(id, token=token, files_metadata=True).siblings
|
| 72 |
-
for x in files_in_repo
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
return x.size
|
| 79 |
|
| 80 |
|
| 81 |
def get_component_wise_memory(
|
|
@@ -211,7 +212,7 @@ with gr.Interface(
|
|
| 211 |
gr.components.Textbox(lines=1, label="t2i_adapter_id", info="Example: TencentARC/t2iadapter_color_sd14v1"),
|
| 212 |
gr.components.Textbox(lines=1, label="hf_token", info="Pass this in case of private repositories."),
|
| 213 |
gr.components.Radio(
|
| 214 |
-
|
| 215 |
label="variant",
|
| 216 |
info="Precision to use for calculation.",
|
| 217 |
),
|
|
@@ -232,7 +233,7 @@ with gr.Interface(
|
|
| 232 |
None,
|
| 233 |
"TencentARC/t2i-adapter-lineart-sdxl-1.0,TencentARC/t2i-adapter-canny-sdxl-1.0",
|
| 234 |
None,
|
| 235 |
-
"
|
| 236 |
None,
|
| 237 |
".safetensors",
|
| 238 |
],
|
|
|
|
| 29 |
`controlnet_id` or `t2i_adapter_id`.
|
| 30 |
"""
|
| 31 |
|
| 32 |
+
ALLOWED_VARIANTS = ["fp32", "fp16", "bf16"]
|
| 33 |
+
|
| 34 |
|
| 35 |
def format_size(num: int) -> str:
|
| 36 |
"""Format size in bytes into a human-readable string.
|
|
|
|
| 71 |
|
| 72 |
def get_individual_model_memory(id, token, variant, extension):
|
| 73 |
files_in_repo = model_info(id, token=token, files_metadata=True).siblings
|
| 74 |
+
candidates = [x for x in files_in_repo if extension in x.rfilename]
|
| 75 |
+
if variant:
|
| 76 |
+
candidate = list(filter(lambda x: variant in x.rfilename, candidates))[0]
|
| 77 |
+
else:
|
| 78 |
+
candidate = list(filter(lambda x: all(var not in x.rfilename for var in ALLOWED_VARIANTS[1:]), candidates))[0]
|
| 79 |
+
return candidate.size
|
|
|
|
| 80 |
|
| 81 |
|
| 82 |
def get_component_wise_memory(
|
|
|
|
| 212 |
gr.components.Textbox(lines=1, label="t2i_adapter_id", info="Example: TencentARC/t2iadapter_color_sd14v1"),
|
| 213 |
gr.components.Textbox(lines=1, label="hf_token", info="Pass this in case of private repositories."),
|
| 214 |
gr.components.Radio(
|
| 215 |
+
ALLOWED_VARIANTS,
|
| 216 |
label="variant",
|
| 217 |
info="Precision to use for calculation.",
|
| 218 |
),
|
|
|
|
| 233 |
None,
|
| 234 |
"TencentARC/t2i-adapter-lineart-sdxl-1.0,TencentARC/t2i-adapter-canny-sdxl-1.0",
|
| 235 |
None,
|
| 236 |
+
"fp16",
|
| 237 |
None,
|
| 238 |
".safetensors",
|
| 239 |
],
|