| import torch |
| from diffusers import ( |
| DDPMScheduler, |
| DDIMScheduler, |
| PNDMScheduler, |
| LMSDiscreteScheduler, |
| EulerAncestralDiscreteScheduler, |
| EulerDiscreteScheduler, |
| DPMSolverMultistepScheduler, |
| ) |
| import base64 |
|
|
def get_variant(str_variant):
    """Map the string "none" (any casing, or the value None) to None.

    Any other value is passed through unchanged.
    """
    normalized = str(str_variant).lower()
    return None if normalized == 'none' else str_variant
|
|
def get_bool(str_bool):
    """Return False only when the value stringifies to "false" (any casing).

    Every other value — including empty strings — yields True.
    """
    return str(str_bool).lower() != 'false'
| |
|
|
def get_data_type(str_data_type):
    """Translate a dtype name into the corresponding torch dtype.

    Recognizes "bfloat16" and "float32"; any other value falls back
    to torch.float16.
    """
    known_dtypes = {
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    return known_dtypes.get(str_data_type, torch.float16)
|
|
def get_tensorfloat32(allow_tensorfloat32):
    """Return True only when the value stringifies to "true" (any casing)."""
    return str(allow_tensorfloat32).lower() == 'true'
|
|
|
|
def get_pipeline(config):
    """Construct the diffusers pipeline described by *config*.

    Selects StableDiffusion3Pipeline when config["pipeline"] requests it,
    otherwise the generic DiffusionPipeline auto-loader, then moves the
    loaded pipeline onto config["device"].
    """
    # Keyword arguments shared by both pipeline classes, parsed from the
    # string-valued config entries via the helpers above.
    shared_kwargs = {
        "use_safetensors": get_bool(config["use_safetensors"]),
        "torch_dtype": get_data_type(config["data_type"]),
        "variant": get_variant(config["variant"]),
    }

    if config.get("pipeline") == "StableDiffusion3Pipeline":
        from diffusers import StableDiffusion3Pipeline as pipeline_cls
    else:
        from diffusers import DiffusionPipeline as pipeline_cls

    return pipeline_cls.from_pretrained(
        config["model"],
        **shared_kwargs,
    ).to(config["device"])
|
|
|
|
def get_scheduler(scheduler, pipeline_config):
    """Instantiate the scheduler named by *scheduler* from *pipeline_config*.

    Unrecognized names fall back to DPMSolverMultistepScheduler — the
    original if/elif chain duplicated that branch verbatim in both the
    explicit case and the else; the dispatch dict removes the duplication
    while preserving the same name-to-class mapping.
    """
    scheduler_classes = {
        "DDPMScheduler": DDPMScheduler,
        "DDIMScheduler": DDIMScheduler,
        "PNDMScheduler": PNDMScheduler,
        "LMSDiscreteScheduler": LMSDiscreteScheduler,
        "EulerAncestralDiscreteScheduler": EulerAncestralDiscreteScheduler,
        "EulerDiscreteScheduler": EulerDiscreteScheduler,
        "DPMSolverMultistepScheduler": DPMSolverMultistepScheduler,
    }
    # Default mirrors the original else-branch.
    scheduler_cls = scheduler_classes.get(scheduler, DPMSolverMultistepScheduler)
    return scheduler_cls.from_config(pipeline_config)
|
|
def dict_list_to_markdown_table(config_history):
    """Render a list of config dicts as a markdown table inside a scrollable div.

    Column order follows the keys of the first entry; entries missing a key
    render an empty cell. Each row is prefixed with a "share" cell holding a
    link whose query string carries the row's dict, base64-encoded from its
    repr. Returns "" for an empty or falsy history.
    """
    if not config_history:
        return ""

    columns = list(config_history[0].keys())
    rows = [
        "| share | " + " | ".join(columns) + " |",
        "| --- | " + " | ".join(["---"] * len(columns)) + " |",
    ]

    for entry in config_history:
        token = base64.b64encode(str(entry).encode()).decode()
        share_link = f'<a target="_blank" href="?config={token}">📎</a>'
        cells = " | ".join(str(entry.get(column, "")) for column in columns)
        rows.append(f"| {share_link} | {cells} |")

    table = "\n".join(rows) + "\n"
    return '<div style="overflow-x: auto;">\n\n' + table + '</div>'
|
|
|
|